mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-09-11 09:45:55 +03:00
Make Grenade fast
Changes shapes to get rid of the Vector, all data is now held in contiguous memory. Add fast c implementations for pooling layers. Now does mnist on my laptop in 12 minutes.
This commit is contained in:
parent
1ec65a414f
commit
6417151620
@ -7,28 +7,20 @@ import Numeric.LinearAlgebra
|
||||
|
||||
main :: IO ()
|
||||
main = defaultMain [
|
||||
bgroup "im2col" [ bench "im2col 3x4" $ whnf (im2colUnsafe 2 2 1 1) ((3><4) [1..])
|
||||
, bench "im2col 28x28" $ whnf (im2colUnsafe 5 5 1 1) ((28><28) [1..])
|
||||
, bench "im2col 100x100" $ whnf (im2colUnsafe 10 10 1 1) ((100><100) [1..])
|
||||
bgroup "im2col" [ bench "im2col 3x4" $ whnf (im2col 2 2 1 1) ((3><4) [1..])
|
||||
, bench "im2col 28x28" $ whnf (im2col 5 5 1 1) ((28><28) [1..])
|
||||
, bench "im2col 100x100" $ whnf (im2col 10 10 1 1) ((100><100) [1..])
|
||||
]
|
||||
, bgroup "im2col_c" [ bench "im2col_c 3x4" $ whnf (im2col 2 2 1 1) ((3><4) [1..])
|
||||
, bench "im2col_c 28x28" $ whnf (im2col 5 5 1 1) ((28><28) [1..])
|
||||
, bench "im2col_c 100x100" $ whnf (im2col 10 10 1 1) ((100><100) [1..])
|
||||
]
|
||||
, bgroup "col2im" [ bench "col2im 3x4" $ whnf (col2imUnsafe 2 2 1 1 3 4) ((6><4) [1..])
|
||||
, bench "col2im 28x28" $ whnf (col2imUnsafe 5 5 1 1 28 28) ((576><25) [1..])
|
||||
, bench "col2im 100x100" $ whnf (col2imUnsafe 10 10 1 1 100 100) ((8281><100) [1..])
|
||||
, bgroup "col2im" [ bench "col2im 3x4" $ whnf (col2im 2 2 1 1 3 4) ((6><4) [1..])
|
||||
, bench "col2im 28x28" $ whnf (col2im 5 5 1 1 28 28) ((576><25) [1..])
|
||||
, bench "col2im 100x100" $ whnf (col2im 10 10 1 1 100 100) ((8281><100) [1..])
|
||||
]
|
||||
, bgroup "col2im_c" [ bench "col2im_c 3x4" $ whnf (col2im 2 2 1 1 3 4) ((6><4) [1..])
|
||||
, bench "col2im_c 28x28" $ whnf (col2im 5 5 1 1 28 28) ((576><25) [1..])
|
||||
, bench "col2im_c 100x100" $ whnf (col2im 10 10 1 1 100 100) ((8281><100) [1..])
|
||||
]
|
||||
, bgroup "poolfw" [ bench "poolforwards 3x4" $ whnf (poolForward 2 2 1 1 2 3) ((3><4) [1..])
|
||||
, bench "poolforwards 28x28" $ whnf (poolForward 5 5 1 1 4 24) ((28><28) [1..])
|
||||
, bench "poolforwards 100x100" $ whnf (poolForward 10 10 1 1 91 91) ((100><100) [1..])
|
||||
, bgroup "poolfw" [ bench "poolforwards 3x4" $ whnf (poolForward 1 3 4 2 2 1 1) ((3><4) [1..])
|
||||
, bench "poolforwards 28x28" $ whnf (poolForward 1 28 28 5 5 1 1) ((28><28) [1..])
|
||||
, bench "poolforwards 100x100" $ whnf (poolForward 1 100 100 10 10 1 1) ((100><100) [1..])
|
||||
]
|
||||
, bgroup "poolbw" [ bench "poolbackwards 3x4" $ whnf (poolBackward 2 2 1 1 ((3><4) [1..])) ((2><3) [1..])
|
||||
, bench "poolbackwards 28x28" $ whnf (poolBackward 5 5 1 1 ((28><28) [1..])) ((24><24) [1..])
|
||||
, bench "poolbackwards 100x100" $ whnf (poolBackward 10 10 1 1 ((100><100) [1..])) ((91><91) [1..])
|
||||
, bgroup "poolbw" [ bench "poolbackwards 3x4" $ whnf (poolBackward 1 3 4 2 2 1 1 ((3><4) [1..])) ((2><3) [1..])
|
||||
, bench "poolbackwards 28x28" $ whnf (poolBackward 1 28 28 5 5 1 1 ((28><28) [1..])) ((24><24) [1..])
|
||||
, bench "poolbackwards 100x100" $ whnf (poolBackward 1 100 100 10 10 1 1 ((100><100) [1..])) ((91><91) [1..])
|
||||
]
|
||||
]
|
||||
|
@ -10,8 +10,6 @@ void im2col_cpu(const double* data_im, int dataOffset, const int channels,
|
||||
double* data_col) {
|
||||
|
||||
data_im += dataOffset;
|
||||
const int output_h = (height - kernel_h) / stride_h + 1;
|
||||
const int output_w = (width - kernel_w) / stride_w + 1;
|
||||
const int channel_size = height * width;
|
||||
|
||||
for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
|
||||
@ -36,12 +34,8 @@ void col2im_cpu(const double* data_col, int dataOffset, const int channels,
|
||||
|
||||
memset(data_im, 0, height * width * channels * sizeof(double));
|
||||
data_col += dataOffset;
|
||||
const int output_h = (height - kernel_h) / stride_h + 1;
|
||||
const int output_w = (width - kernel_w) / stride_w + 1;
|
||||
const int channel_size = height * width;
|
||||
|
||||
int offsetRow = 0;
|
||||
int offsetColumn = 0;
|
||||
const int channel_size = height * width;
|
||||
|
||||
for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
|
||||
for (int fitting_width = 0; fitting_width <= (width - kernel_w); fitting_width += stride_w) {
|
||||
@ -57,3 +51,84 @@ void col2im_cpu(const double* data_col, int dataOffset, const int channels,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline int max ( int a, int b ) { return a > b ? a : b; }
|
||||
|
||||
void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
|
||||
const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
double* data_pooled) {
|
||||
|
||||
data_im += dataOffset;
|
||||
|
||||
const int channel_size = height * width;
|
||||
|
||||
for (int channel = 0; channel < channels; channel++) {
|
||||
for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
|
||||
for (int fitting_width = 0; fitting_width <= (width - kernel_w); fitting_width += stride_w) {
|
||||
// Start with the value in 0,0
|
||||
int max_value = data_im[fitting_height * width + fitting_width + channel_size * channel];
|
||||
// Initial row, skipping the corner we've done
|
||||
for (int kernel_col = 1; kernel_col < kernel_w; kernel_col++) {
|
||||
int input_row = fitting_height;
|
||||
int input_col = fitting_width + kernel_col;
|
||||
max_value = max ( max_value, data_im[input_row * width + input_col + channel_size * channel] );
|
||||
}
|
||||
// The remaining rows
|
||||
for (int kernel_row = 1; kernel_row < kernel_h; kernel_row++) {
|
||||
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
|
||||
int input_row = fitting_height + kernel_row;
|
||||
int input_col = fitting_width + kernel_col;
|
||||
max_value = max ( max_value, data_im[input_row * width + input_col + channel_size * channel] );
|
||||
}
|
||||
}
|
||||
*(data_pooled++) = max_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void pool_backwards_cpu(const double* data_im, int data_im_offset,
|
||||
const double* data_pooled, int data_pooled_offset,
|
||||
const int channels, const int height, const int width, const int kernel_h,
|
||||
const int kernel_w, const int stride_h, const int stride_w,
|
||||
double* data_backgrad ) {
|
||||
|
||||
data_im += data_im_offset;
|
||||
data_pooled += data_pooled_offset;
|
||||
memset(data_backgrad, 0, height * width * channels * sizeof(double));
|
||||
|
||||
const int channel_size = height * width;
|
||||
|
||||
for (int channel = 0; channel < channels; channel++) {
|
||||
for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
|
||||
for (int fitting_width = 0; fitting_width <= (width - kernel_w); fitting_width += stride_w) {
|
||||
int max_index = fitting_height * width + fitting_width + channel_size * channel;
|
||||
int max_value = data_im[max_index];
|
||||
for (int kernel_col = 1; kernel_col < kernel_w; kernel_col++) {
|
||||
int input_row = fitting_height;
|
||||
int input_col = fitting_width + kernel_col;
|
||||
int data_index = input_row * width + input_col + channel_size * channel;
|
||||
int data_value = data_im[data_index];
|
||||
if ( data_value > max_value ) {
|
||||
max_value = data_value;
|
||||
max_index = data_index;
|
||||
}
|
||||
}
|
||||
for (int kernel_row = 1; kernel_row < kernel_h; kernel_row++) {
|
||||
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
|
||||
int input_row = fitting_height + kernel_row;
|
||||
int input_col = fitting_width + kernel_col;
|
||||
int data_index = input_row * width + input_col + channel_size * channel;
|
||||
int data_value = data_im[data_index];
|
||||
if ( data_value > max_value ) {
|
||||
max_value = data_value;
|
||||
max_index = data_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
data_backgrad[max_index] += *(data_pooled++);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,3 +11,14 @@ void col2im_cpu(const double* data_col, int dataOffset, const int channels,
|
||||
const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
double* data_im);
|
||||
|
||||
void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
|
||||
const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
const int stride_h, const int stride_w,
|
||||
double* data_pooled);
|
||||
|
||||
void pool_backwards_cpu(const double* data_im, int data_im_offset,
|
||||
const double* data_pooled, int data_pooled_offset,
|
||||
const int channels, const int height, const int width, const int kernel_h,
|
||||
const int kernel_w, const int stride_h, const int stride_w,
|
||||
double* data_backgrad );
|
||||
|
@ -60,6 +60,8 @@ library
|
||||
includes: cbits/im2col.h
|
||||
c-sources: cbits/im2col.c
|
||||
|
||||
cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
|
||||
|
||||
executable feedforward
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/feedforward.hs
|
||||
|
@ -19,10 +19,10 @@ module Grenade.Core.Shape (
|
||||
) where
|
||||
|
||||
import Data.Singletons.TypeLits
|
||||
import GHC.TypeLits
|
||||
|
||||
import Numeric.LinearAlgebra.Static
|
||||
|
||||
import Grenade.Core.Vector
|
||||
|
||||
-- | The current shapes we accept.
|
||||
-- at the moment this is just one, two, and three dimensional
|
||||
@ -35,34 +35,39 @@ data Shape =
|
||||
instance Num (S' x) where
|
||||
(+) (S1D' x) (S1D' y) = S1D' (x + y)
|
||||
(+) (S2D' x) (S2D' y) = S2D' (x + y)
|
||||
(+) (S3D' x) (S3D' y) = S3D' (vectorZip (+) x y)
|
||||
(+) (S3D' x) (S3D' y) = S3D' (x + y)
|
||||
(+) _ _ = error "Impossible to have different constructors for the same shaped network"
|
||||
|
||||
(-) (S1D' x) (S1D' y) = S1D' (x - y)
|
||||
(-) (S2D' x) (S2D' y) = S2D' (x - y)
|
||||
(-) (S3D' x) (S3D' y) = S3D' (vectorZip (-) x y)
|
||||
(-) (S3D' x) (S3D' y) = S3D' (x - y)
|
||||
(-) _ _ = error "Impossible to have different constructors for the same shaped network"
|
||||
|
||||
(*) (S1D' x) (S1D' y) = S1D' (x * y)
|
||||
(*) (S2D' x) (S2D' y) = S2D' (x * y)
|
||||
(*) (S3D' x) (S3D' y) = S3D' (vectorZip (*) x y)
|
||||
(*) (S3D' x) (S3D' y) = S3D' (x * y)
|
||||
(*) _ _ = error "Impossible to have different constructors for the same shaped network"
|
||||
|
||||
abs (S1D' x) = S1D' (abs x)
|
||||
abs (S2D' x) = S2D' (abs x)
|
||||
abs (S3D' x) = S3D' (fmap abs x)
|
||||
abs (S3D' x) = S3D' (abs x)
|
||||
|
||||
signum (S1D' x) = S1D' (signum x)
|
||||
signum (S2D' x) = S2D' (signum x)
|
||||
signum (S3D' x) = S3D' (fmap signum x)
|
||||
signum (S3D' x) = S3D' (signum x)
|
||||
|
||||
fromInteger _ = error "Unimplemented: fromInteger on Shape"
|
||||
|
||||
-- | Given a Shape n, these are the possible data structures with that shape.
|
||||
-- All shapes are held in contiguous memory.
|
||||
-- 3D is held in a matrix (usually row oriented) which has height depth * rows.
|
||||
data S' (n :: Shape) where
|
||||
S1D' :: (KnownNat o) => R o -> S' ('D1 o)
|
||||
S2D' :: (KnownNat rows, KnownNat columns) => L rows columns -> S' ('D2 rows columns)
|
||||
S3D' :: (KnownNat rows, KnownNat columns, KnownNat depth) => Vector depth (L rows columns) -> S' ('D3 rows columns depth)
|
||||
S1D' :: ( KnownNat o ) => R o -> S' ('D1 o)
|
||||
S2D' :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S' ('D2 rows columns)
|
||||
S3D' :: ( KnownNat rows
|
||||
, KnownNat columns
|
||||
, KnownNat depth
|
||||
, KnownNat (rows * depth)) => L (rows * depth) columns -> S' ('D3 rows columns depth)
|
||||
|
||||
instance Show (S' n) where
|
||||
show (S1D' a) = "S1D' " ++ show a
|
||||
|
@ -30,7 +30,6 @@ import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)
|
||||
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core.Shape
|
||||
import Grenade.Core.Vector
|
||||
import Grenade.Layers.Internal.Convolution
|
||||
|
||||
-- | A convolution layer for a neural network.
|
||||
@ -145,6 +144,7 @@ instance ( KnownNat kernelRows
|
||||
, ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
|
||||
, ((outputCols - 1) * strideCols) ~ (inputCols - kernelCols)
|
||||
, KnownNat (kernelRows * kernelCols * 1)
|
||||
, KnownNat (outputRows * filters)
|
||||
) => Layer (Convolution 1 filters kernelRows kernelCols strideRows strideCols) ('D2 inputRows inputCols) ('D3 outputRows outputCols filters) where
|
||||
runForwards (Convolution kernel _) (S2D' input) =
|
||||
let ex = extract input
|
||||
@ -158,8 +158,8 @@ instance ( KnownNat kernelRows
|
||||
c = im2col kx ky sx sy ex
|
||||
mt = c LA.<> ek
|
||||
r = col2vid 1 1 1 1 ox oy mt
|
||||
rs = fmap (fromJust . create) r
|
||||
in S3D' $ mkVector rs
|
||||
rs = fromJust . create $ r
|
||||
in S3D' rs
|
||||
|
||||
runBackwards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
|
||||
let ex = extract input
|
||||
@ -174,7 +174,7 @@ instance ( KnownNat kernelRows
|
||||
|
||||
c = im2col kx ky sx sy ex
|
||||
|
||||
eo = vecToList $ fmap extract dEdy
|
||||
eo = extract dEdy
|
||||
ek = extract kernel
|
||||
|
||||
vs = vid2col 1 1 1 1 ox oy eo
|
||||
@ -201,9 +201,10 @@ instance ( KnownNat kernelRows
|
||||
, ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
|
||||
, ((outputCols - 1) * strideCols) ~ (inputCols - kernelCols)
|
||||
, KnownNat (kernelRows * kernelCols * channels)
|
||||
, KnownNat (outputRows * filters)
|
||||
) => Layer (Convolution channels filters kernelRows kernelCols strideRows strideCols) ('D3 inputRows inputCols channels) ('D3 outputRows outputCols filters) where
|
||||
runForwards (Convolution kernel _) (S3D' input) =
|
||||
let ex = vecToList $ fmap extract input
|
||||
let ex = extract input
|
||||
ek = extract kernel
|
||||
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
|
||||
@ -217,10 +218,10 @@ instance ( KnownNat kernelRows
|
||||
c = vid2col kx ky sx sy ix iy ex
|
||||
mt = c LA.<> ek
|
||||
r = col2vid 1 1 1 1 ox oy mt
|
||||
rs = fmap (fromJust . create) r
|
||||
in S3D' $ mkVector rs
|
||||
rs = fromJust . create $ r
|
||||
in S3D' rs
|
||||
runBackwards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
|
||||
let ex = vecToList $ fmap extract input
|
||||
let ex = extract input
|
||||
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
|
||||
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||
@ -232,7 +233,7 @@ instance ( KnownNat kernelRows
|
||||
|
||||
c = vid2col kx ky sx sy ix iy ex
|
||||
|
||||
eo = vecToList $ fmap extract dEdy
|
||||
eo = extract dEdy
|
||||
ek = extract kernel
|
||||
|
||||
vs = vid2col 1 1 1 1 ox oy eo
|
||||
@ -242,4 +243,4 @@ instance ( KnownNat kernelRows
|
||||
dW = vs LA.<> tr ek
|
||||
|
||||
xW = col2vid kx ky sx sy ix iy dW
|
||||
in (Convolution' kN, S3D' . mkVector . fmap (fromJust . create) $ xW)
|
||||
in (Convolution' kN, S3D' . fromJust . create $ xW)
|
||||
|
@ -5,20 +5,19 @@
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
|
||||
module Grenade.Layers.Flatten (
|
||||
FlattenLayer (..)
|
||||
) where
|
||||
|
||||
import Data.Proxy
|
||||
import Data.Singletons.TypeLits
|
||||
import GHC.TypeLits
|
||||
|
||||
import Numeric.LinearAlgebra.Static
|
||||
import Numeric.LinearAlgebra.Data as LA (flatten, toList, takesV, reshape, vjoin)
|
||||
import Numeric.LinearAlgebra.Data as LA (flatten, toList)
|
||||
|
||||
import Grenade.Core.Vector
|
||||
import Grenade.Core.Shape
|
||||
import Grenade.Core.Network
|
||||
|
||||
@ -31,21 +30,10 @@ instance UpdateLayer FlattenLayer where
|
||||
createRandom = return FlattenLayer
|
||||
|
||||
|
||||
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
|
||||
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * z)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
|
||||
runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
|
||||
runBackwards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
|
||||
|
||||
instance (KnownNat a, KnownNat x, KnownNat y, KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
|
||||
runForwards _ (S3D' y) = S1D' . raiseShapeError . create . vjoin . vecToList . fmap (flatten . extract) $ y
|
||||
runBackwards _ _ (S1D' o) =
|
||||
let x' = fromIntegral $ natVal (Proxy :: Proxy x)
|
||||
y' = fromIntegral $ natVal (Proxy :: Proxy y)
|
||||
z' = fromIntegral $ natVal (Proxy :: Proxy z)
|
||||
vecs = takesV (replicate z' (x' * y')) (extract o)
|
||||
ls = fmap (raiseShapeError . create . reshape y') vecs
|
||||
ls' = mkVector ls :: Vector z (L x y)
|
||||
in ((), S3D' ls')
|
||||
|
||||
raiseShapeError :: Maybe a -> a
|
||||
raiseShapeError (Just x) = x
|
||||
raiseShapeError Nothing = error "Static shape creation from Flatten layer produced the wrong result"
|
||||
instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
|
||||
runForwards _ (S3D' y) = S1D' . fromList . toList . flatten . extract $ y
|
||||
runBackwards _ _ (S1D' y) = ((), S3D' . fromList . toList . unwrap $ y)
|
||||
|
@ -1,79 +1,23 @@
|
||||
{-# LANGUAGE ForeignFunctionInterface #-}
|
||||
module Grenade.Layers.Internal.Convolution (
|
||||
col2vidUnsafe
|
||||
, col2imUnsafe
|
||||
, vid2colUnsafe
|
||||
, im2colUnsafe
|
||||
, im2col
|
||||
im2col
|
||||
, col2im
|
||||
, col2vid
|
||||
, vid2col
|
||||
, fittingStarts
|
||||
, unsafeModifyMatrix
|
||||
) where
|
||||
|
||||
import Control.Monad.ST ( ST, runST )
|
||||
|
||||
import Data.STRef ( newSTRef, modifySTRef', writeSTRef, readSTRef )
|
||||
import Data.Foldable ( forM_ )
|
||||
import Data.Traversable ( forM )
|
||||
|
||||
import Foreign ( mallocForeignPtrArray0, withForeignPtr )
|
||||
import Foreign.Ptr ( Ptr )
|
||||
import Foreign.Storable( Storable )
|
||||
|
||||
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||
import Numeric.LinearAlgebra ( Matrix, flatten, rows, cols )
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
import System.IO.Unsafe ( unsafePerformIO )
|
||||
|
||||
-- This module provides provides im2col function and friends, ala caffe.
|
||||
--
|
||||
-- /* From Caffe */
|
||||
-- @
|
||||
-- void col2im_cpu(const Dtype* data_col, const int channels,
|
||||
-- const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
-- const int pad_h, const int pad_w,
|
||||
-- const int stride_h, const int stride_w,
|
||||
-- const int dilation_h, const int dilation_w,
|
||||
-- Dtype* data_im) {
|
||||
-- caffe_set(height * width * channels, Dtype(0), data_im);
|
||||
-- const int output_h = (height + 2 * pad_h -
|
||||
-- (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||
-- const int output_w = (width + 2 * pad_w -
|
||||
-- (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||
-- const int channel_size = height * width;
|
||||
-- for (int channel = channels; channel--; data_im += channel_size) {
|
||||
-- for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
|
||||
-- for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
|
||||
-- int input_row = -pad_h + kernel_row * dilation_h;
|
||||
-- for (int output_rows = output_h; output_rows; output_rows--) {
|
||||
-- if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
|
||||
-- data_col += output_w;
|
||||
-- } else {
|
||||
-- int input_col = -pad_w + kernel_col * dilation_w;
|
||||
-- for (int output_col = output_w; output_col; output_col--) {
|
||||
-- if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
|
||||
-- data_im[input_row * width + input_col] += *data_col;
|
||||
-- }
|
||||
-- data_col++;
|
||||
-- input_col += stride_w;
|
||||
-- }
|
||||
-- }
|
||||
-- input_row += stride_h;
|
||||
-- }
|
||||
-- }
|
||||
-- }
|
||||
-- }
|
||||
-- }
|
||||
-- @
|
||||
--
|
||||
|
||||
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
|
||||
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
col2vid kernelRows kernelColumns strideRows strideColumns height width dataCol =
|
||||
let channels = cols dataCol `div` (kernelRows * kernelColumns)
|
||||
retMat = col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol
|
||||
in (\f -> subMatrix (f * height, 0) (height, width) retMat) <$> [0..channels -1]
|
||||
in col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol
|
||||
|
||||
col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
col2im kernelRows kernelColumns strideRows strideColumns height width dataCol =
|
||||
@ -94,75 +38,14 @@ col2im_c channels height width kernelRows kernelColumns strideRows strideColumns
|
||||
let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
|
||||
return $ U.matrixFromVector U.RowMajor (height * channels) width matVec
|
||||
|
||||
foreign import ccall safe
|
||||
foreign import ccall unsafe
|
||||
col2im_cpu
|
||||
:: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
|
||||
|
||||
|
||||
-- | col2im function.
|
||||
--
|
||||
-- Takes a column patch, and reconstitutes it into a normal image.
|
||||
-- Does not do any bounds checking on the matrix, so should only
|
||||
-- be called once the sizes are ensured correct.
|
||||
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
|
||||
let columnMatrixRows = rows columnMatrix
|
||||
|
||||
dataIm <- U.newMatrix 0 destinationRows destinationCols
|
||||
offsetR <- newSTRef 0
|
||||
offsetC <- newSTRef 0
|
||||
|
||||
forM_ [0 .. columnMatrixRows - 1] $ \inputRow -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
|
||||
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix inputRow inputColumn)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
|
||||
if offsetC' + kernelColumns < destinationCols
|
||||
then modifySTRef' offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef' offsetR (+ strideRows)
|
||||
|
||||
return dataIm
|
||||
|
||||
-- | col2vid function.
|
||||
--
|
||||
-- Takes a column patch image, and reconstitutes it into a normal image with multiple channels.
|
||||
-- Does not do any bounds checking on the matrix, so should only
|
||||
-- be called once the sizes are ensured correct.
|
||||
col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
|
||||
col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do
|
||||
let columnMatrixRows = rows columnMatrix
|
||||
let filters = cols columnMatrix `div` (kernelRows * kernelColumns)
|
||||
|
||||
forM [0 .. filters - 1] $ \iter -> do
|
||||
let offsetM = iter * (kernelRows * kernelColumns)
|
||||
dataIm <- U.newMatrix 0 destinationRows destinationCols
|
||||
offsetR <- newSTRef 0
|
||||
offsetC <- newSTRef 0
|
||||
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
|
||||
inputColumn <- newSTRef offsetM
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
|
||||
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
|
||||
ic <- readSTRef inputColumn
|
||||
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix ir ic)
|
||||
modifySTRef' inputColumn (+1)
|
||||
|
||||
if offsetC' + kernelColumns < destinationCols
|
||||
then modifySTRef' offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef' offsetR (+ strideRows)
|
||||
|
||||
U.unsafeFreezeMatrix dataIm
|
||||
|
||||
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
|
||||
vid2col kernelRows kernelColumns strideRows strideColumns height width dataVid =
|
||||
let channels = length dataVid
|
||||
in im2col_c channels height width kernelRows kernelColumns strideRows strideColumns (foldl1 (===) dataVid)
|
||||
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
vid2col kernelRows kernelColumns strideRows strideColumns height width dataVid =
|
||||
let channels = rows dataVid `div` height
|
||||
in im2col_c channels height width kernelRows kernelColumns strideRows strideColumns dataVid
|
||||
|
||||
|
||||
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
@ -190,89 +73,6 @@ im2col_c channels height width kernelRows kernelColumns strideRows strideColumns
|
||||
let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * kernelSize * channels)
|
||||
return $ U.matrixFromVector U.RowMajor numberOfPatches (kernelSize * channels) matVec
|
||||
|
||||
foreign import ccall safe
|
||||
foreign import ccall unsafe
|
||||
im2col_cpu
|
||||
:: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
|
||||
|
||||
|
||||
unsafeModifyMatrix :: (Storable t) => U.STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
|
||||
unsafeModifyMatrix x r c f = U.unsafeReadMatrix x r c >>= U.unsafeWriteMatrix x r c . f
|
||||
{-# INLINE unsafeModifyMatrix #-}
|
||||
|
||||
|
||||
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
|
||||
-- convolution. Takes into account the stride and kernel size.
|
||||
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
|
||||
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
|
||||
let rs = fittingStart nrows kernelrows steprows
|
||||
cs = fittingStart ncols kernelcols stepcolsh
|
||||
in concatMap ( \r -> fmap (\c -> (r , c)) cs ) rs
|
||||
|
||||
-- | Returns the starting sub vector which fit inside the larger vector for the
|
||||
-- convolution. Takes into account the stride and kernel size.
|
||||
fittingStart :: Int -> Int -> Int -> [Int]
|
||||
fittingStart width kernel steps =
|
||||
let go left | left + kernel < width
|
||||
= left : go (left + steps)
|
||||
| left + kernel == width
|
||||
= [left]
|
||||
| otherwise
|
||||
= []
|
||||
in go 0
|
||||
|
||||
|
||||
|
||||
|
||||
-- | Old functions (useful for sanity checking and benchmarking)
|
||||
|
||||
vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
|
||||
vid2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do
|
||||
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
|
||||
kernelSize = kernelRows * kernelColumns
|
||||
numberOfPatches = length starts
|
||||
channels = length dataVid
|
||||
|
||||
dataCol <- U.newUndefinedMatrix U.RowMajor numberOfPatches (channels * kernelSize)
|
||||
|
||||
offsetC <- newSTRef 0
|
||||
|
||||
forM_ dataVid $ \dataIm -> do
|
||||
inputRowRef <- newSTRef 0
|
||||
offsetC' <- readSTRef offsetC
|
||||
forM_ starts $ \(startRow, startCol) -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
inputRow <- readSTRef inputRowRef
|
||||
forM_ [startRow .. startRow + kernelRows -1] $ \kr ->
|
||||
forM_ [startCol .. startCol + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
U.unsafeWriteMatrix dataCol inputRow (inputColumn + offsetC') (U.atM' dataIm kr kc)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
modifySTRef' inputRowRef (+1)
|
||||
|
||||
modifySTRef' offsetC (+ kernelSize)
|
||||
|
||||
return dataCol
|
||||
|
||||
im2colUnsafe :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatrix $ do
|
||||
let starts = fittingStarts (rows dataIm) kernelRows striderows (cols dataIm) kernelColumns stridecols
|
||||
kernelSize = kernelRows * kernelColumns
|
||||
numberOfPatches = length starts
|
||||
|
||||
dataCol <- U.newUndefinedMatrix U.RowMajor numberOfPatches kernelSize
|
||||
|
||||
inputRowRef <- newSTRef 0
|
||||
forM_ starts $ \(startRow, startCol) -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
inputRow <- readSTRef inputRowRef
|
||||
forM_ [startRow .. startRow + kernelRows -1] $ \kr ->
|
||||
forM_ [startCol .. startCol + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
U.unsafeWriteMatrix dataCol inputRow inputColumn (U.atM' dataIm kr kc)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
modifySTRef' inputRowRef (+1)
|
||||
|
||||
return dataCol
|
||||
|
||||
|
||||
{-# ANN module "HLint: ignore Reduce duplication" #-}
|
||||
|
@ -1,69 +1,55 @@
|
||||
{-# LANGUAGE ForeignFunctionInterface #-}
|
||||
module Grenade.Layers.Internal.Pooling (
|
||||
poolForward
|
||||
, poolBackward
|
||||
, poolForwardList
|
||||
, poolBackwardList
|
||||
) where
|
||||
|
||||
import Data.Foldable ( forM_ )
|
||||
import Data.Function ( on )
|
||||
import Data.List ( maximumBy )
|
||||
import Foreign ( mallocForeignPtrArray0, withForeignPtr )
|
||||
import Foreign.Ptr ( Ptr )
|
||||
|
||||
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||
import qualified Numeric.LinearAlgebra as LA
|
||||
import Numeric.LinearAlgebra ( Matrix , flatten )
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
import Grenade.Layers.Internal.Convolution
|
||||
import System.IO.Unsafe ( unsafePerformIO )
|
||||
|
||||
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
poolForward nrows ncols srows scols outputRows outputCols m =
|
||||
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
|
||||
in poolForwardFit starts nrows ncols outputRows outputCols m
|
||||
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
poolForward channels height width kernelRows kernelColumns strideRows strideColumns dataIm =
|
||||
let vec = flatten dataIm
|
||||
rowOut = (height - kernelRows) `div` strideRows + 1
|
||||
colOut = (width - kernelColumns) `div` strideColumns + 1
|
||||
numberOfPatches = rowOut * colOut
|
||||
in unsafePerformIO $ do
|
||||
outPtr <- mallocForeignPtrArray0 (numberOfPatches * channels)
|
||||
let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
|
||||
|
||||
poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double)
|
||||
poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
|
||||
let starts = fittingStarts inRows nrows srows inCols ncols scols
|
||||
in poolForwardFit starts nrows ncols outputRows outputCols <$> ms
|
||||
withForeignPtr inPtr $ \inPtr' ->
|
||||
withForeignPtr outPtr $ \outPtr' ->
|
||||
pool_forwards_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
|
||||
|
||||
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
poolForwardFit starts nrows ncols _ outputCols m =
|
||||
let els = fmap (\start -> unsafeMaxElementSubmatrix start (nrows, ncols) m) starts
|
||||
in LA.matrix outputCols els
|
||||
let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * channels)
|
||||
return $ U.matrixFromVector U.RowMajor (rowOut * channels) colOut matVec
|
||||
|
||||
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
poolBackward krows kcols srows scols inputMatrix gradientMatrix =
|
||||
let inRows = rows inputMatrix
|
||||
inCols = cols inputMatrix
|
||||
starts = fittingStarts inRows krows srows inCols kcols scols
|
||||
in poolBackwardFit starts krows kcols inputMatrix gradientMatrix
|
||||
foreign import ccall unsafe
|
||||
pool_forwards_cpu
|
||||
:: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
|
||||
|
||||
poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double)
|
||||
poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
|
||||
let starts = fittingStarts inRows krows srows inCols kcols scols
|
||||
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
|
||||
poolBackward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
poolBackward channels height width kernelRows kernelColumns strideRows strideColumns dataCol dataGrad =
|
||||
let vecIm = flatten dataCol
|
||||
vecGrad = flatten dataGrad
|
||||
in unsafePerformIO $ do
|
||||
outPtr <- mallocForeignPtrArray0 (height * width * channels)
|
||||
let (imPtr, imOffset, _) = U.unsafeToForeignPtr vecIm
|
||||
let (gradPtr, gradOffset, _) = U.unsafeToForeignPtr vecGrad
|
||||
|
||||
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $ do
|
||||
let inRows = rows inputMatrix
|
||||
inCols = cols inputMatrix
|
||||
gradCol = cols gradientMatrix
|
||||
extent = (krows, kcols)
|
||||
withForeignPtr imPtr $ \imPtr' ->
|
||||
withForeignPtr gradPtr $ \gradPtr' ->
|
||||
withForeignPtr outPtr $ \outPtr' ->
|
||||
pool_backwards_cpu imPtr' imOffset gradPtr' gradOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
|
||||
|
||||
retM <- U.newMatrix 0 inRows inCols
|
||||
|
||||
forM_ (zip [0..] starts) $ \(ix, start) -> do
|
||||
let loc = unsafeMaxIndexSubMatrix start extent inputMatrix
|
||||
uncurry (unsafeModifyMatrix retM) loc ((+) $ uncurry (U.atM' gradientMatrix) $ divMod ix gradCol)
|
||||
|
||||
return retM
|
||||
|
||||
unsafeMaxElementSubmatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> Double
|
||||
unsafeMaxElementSubmatrix starts extent m = uncurry (U.atM' m) $ unsafeMaxIndexSubMatrix starts extent m
|
||||
|
||||
unsafeMaxIndexSubMatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> (Int, Int)
|
||||
unsafeMaxIndexSubMatrix (startRow, startCol) (extentRow, extentCold) m =
|
||||
let mrows = [startRow .. startRow + extentRow - 1]
|
||||
mcols = [startCol .. startCol + extentCold - 1]
|
||||
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
|
||||
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
|
||||
let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
|
||||
return $ U.matrixFromVector U.RowMajor (height * channels) width matVec
|
||||
|
||||
foreign import ccall unsafe
|
||||
pool_backwards_cpu
|
||||
:: Ptr Double -> Int -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
|
||||
|
@ -12,7 +12,6 @@ module Grenade.Layers.Logit (
|
||||
|
||||
import Data.Singletons.TypeLits
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core.Vector
|
||||
import Grenade.Core.Shape
|
||||
|
||||
-- | A Logit layer.
|
||||
@ -36,8 +35,8 @@ instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where
|
||||
runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))
|
||||
|
||||
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where
|
||||
runForwards _ (S3D' y) = S3D' (fmap logistic y)
|
||||
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> logistic' y' * dEdy') y dEdy))
|
||||
runForwards _ (S3D' y) = S3D' (logistic y)
|
||||
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (logistic' y * dEdy))
|
||||
|
||||
|
||||
logistic :: Floating a => a -> a
|
||||
|
@ -21,7 +21,6 @@ import GHC.TypeLits
|
||||
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core.Shape
|
||||
import Grenade.Core.Vector
|
||||
import Grenade.Layers.Internal.Pooling
|
||||
|
||||
import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows)
|
||||
@ -57,24 +56,26 @@ instance ( KnownNat kernelRows
|
||||
, ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
|
||||
) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
|
||||
runForwards Pooling (S2D' input) =
|
||||
let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||
let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||
width = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
|
||||
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||
ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
|
||||
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
|
||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
|
||||
ex = extract input
|
||||
r = poolForward kx ky sx sy ox oy $ ex
|
||||
r = poolForward 1 height width kx ky sx sy ex
|
||||
rs = fromJust . create $ r
|
||||
in S2D' $ rs
|
||||
runBackwards Pooling (S2D' input) (S2D' dEdy) =
|
||||
let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||
let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||
width = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
|
||||
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||
ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
|
||||
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
|
||||
ex = extract input
|
||||
eo = extract dEdy
|
||||
vs = poolBackward kx ky sx sy ex eo
|
||||
vs = poolBackward 1 height width kx ky sx sy ex eo
|
||||
in ((), S2D' . fromJust . create $ vs)
|
||||
|
||||
|
||||
@ -87,6 +88,8 @@ instance ( KnownNat kernelRows
|
||||
, KnownNat inputColumns
|
||||
, KnownNat outputRows
|
||||
, KnownNat outputColumns
|
||||
, KnownNat channels
|
||||
, KnownNat (outputRows * channels)
|
||||
, ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
|
||||
, ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
|
||||
) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
|
||||
@ -97,11 +100,10 @@ instance ( KnownNat kernelRows
|
||||
ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
|
||||
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
|
||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
|
||||
ex = fmap extract input
|
||||
r = poolForwardList kx ky sx sy ix iy ox oy ex
|
||||
rs = fmap (fromJust . create) r
|
||||
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
|
||||
ex = extract input
|
||||
r = poolForward ch ix iy kx ky sx sy ex
|
||||
rs = fromJust . create $ r
|
||||
in S3D' rs
|
||||
runBackwards Pooling (S3D' input) (S3D' dEdy) =
|
||||
let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||
@ -110,8 +112,8 @@ instance ( KnownNat kernelRows
|
||||
ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
|
||||
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
|
||||
ex = fmap extract input
|
||||
eo = fmap extract dEdy
|
||||
ez = vectorZip (,) ex eo
|
||||
vs = poolBackwardList kx ky sx sy ix iy ez
|
||||
in ((), S3D' . fmap (fromJust . create) $ vs)
|
||||
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
|
||||
ex = extract input
|
||||
eo = extract dEdy
|
||||
vs = poolBackward ch ix iy kx ky sx sy ex eo
|
||||
in ((), S3D' . fromJust . create $ vs)
|
||||
|
@ -10,7 +10,6 @@ module Grenade.Layers.Relu (
|
||||
) where
|
||||
|
||||
import GHC.TypeLits
|
||||
import Grenade.Core.Vector
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core.Shape
|
||||
|
||||
@ -44,9 +43,9 @@ instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where
|
||||
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
|
||||
|
||||
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j k) where
|
||||
runForwards _ (S3D' y) = S3D' (fmap relu y)
|
||||
runForwards _ (S3D' y) = S3D' (relu y)
|
||||
where
|
||||
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
|
||||
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> relu' y' * dEdy') y dEdy))
|
||||
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (relu' y * dEdy))
|
||||
where
|
||||
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
|
||||
|
@ -10,7 +10,6 @@ module Grenade.Layers.Tanh (
|
||||
) where
|
||||
|
||||
import GHC.TypeLits
|
||||
import Grenade.Core.Vector
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core.Shape
|
||||
|
||||
@ -33,8 +32,8 @@ instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where
|
||||
runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))
|
||||
|
||||
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where
|
||||
runForwards _ (S3D' y) = S3D' (fmap tanh y)
|
||||
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> tanh' y' * dEdy') y dEdy))
|
||||
runForwards _ (S3D' y) = S3D' (tanh y)
|
||||
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (tanh' y * dEdy))
|
||||
|
||||
tanh' :: (Floating a) => a -> a
|
||||
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
|
||||
|
@ -5,7 +5,6 @@
|
||||
module Test.Grenade.Layers.Convolution where
|
||||
|
||||
import Grenade.Core.Shape
|
||||
import Grenade.Core.Vector as Grenade
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Layers.Convolution
|
||||
import Grenade.Layers.Internal.Convolution
|
||||
@ -16,21 +15,6 @@ import qualified Numeric.LinearAlgebra.Static as HStatic
|
||||
import Test.QuickCheck hiding ((><))
|
||||
|
||||
prop_im2col_no_stride = once $
|
||||
let input = (3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
expected = (6><4)
|
||||
[ 1.0, 2.0, 5.0, 6.0
|
||||
, 2.0, 3.0, 6.0, 7.0
|
||||
, 3.0, 4.0, 7.0, 8.0
|
||||
, 5.0, 6.0, 9.0, 10.0
|
||||
, 6.0, 7.0, 10.0, 11.0
|
||||
, 7.0, 8.0, 11.0, 12.0 ]
|
||||
out = im2colUnsafe 2 2 1 1 input
|
||||
in expected === out
|
||||
|
||||
prop_im2col_c = once $
|
||||
let input = (3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
@ -46,19 +30,6 @@ prop_im2col_c = once $
|
||||
in expected === out
|
||||
|
||||
prop_im2col_stride = once $
|
||||
let input = (3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
expected = (4><4)
|
||||
[ 1.0, 2.0, 5.0, 6.0
|
||||
, 3.0, 4.0, 7.0, 8.0
|
||||
, 5.0, 6.0, 9.0, 10.0
|
||||
, 7.0, 8.0, 11.0, 12.0 ]
|
||||
out = im2colUnsafe 2 2 1 2 input
|
||||
in expected === out
|
||||
|
||||
prop_im2col_c_stride = once $
|
||||
let input = (3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
@ -72,17 +43,6 @@ prop_im2col_c_stride = once $
|
||||
in expected === out
|
||||
|
||||
prop_im2col_other = once $
|
||||
let input = (3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
expected = (2><6)
|
||||
[ 1.0, 2.0, 5.0, 6.0 , 9.0, 10.0
|
||||
, 3.0, 4.0, 7.0, 8.0 , 11.0 ,12.0 ]
|
||||
out = im2colUnsafe 3 2 1 2 input
|
||||
in expected === out
|
||||
|
||||
prop_im2col_c_other = once $
|
||||
let input = (3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
@ -93,10 +53,6 @@ prop_im2col_c_other = once $
|
||||
out = im2col 3 2 1 2 input
|
||||
in expected === out
|
||||
|
||||
prop_im2col_bigger = once $
|
||||
let input = (7><7) [ 1.0 .. ]
|
||||
in im2colUnsafe 5 5 2 2 input === im2col 5 5 2 2 input
|
||||
|
||||
-- If there's no overlap (stride is the same size as the kernel)
|
||||
-- then col2im . im2col should be symmetric.
|
||||
prop_im2col_sym_on_same_stride = once $
|
||||
@ -104,20 +60,9 @@ prop_im2col_sym_on_same_stride = once $
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
out = col2im 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
|
||||
out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input
|
||||
in input === out
|
||||
|
||||
-- If there's no overlap (stride is the same size as the kernel)
|
||||
-- then col2im . im2col should be symmetric.
|
||||
prop_im2colunsafe_sym_on_same_stride = once $
|
||||
let input = (3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
out = col2im 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
|
||||
in input === out
|
||||
|
||||
|
||||
-- If there is an overlap, then the gradient passed back should be
|
||||
-- the sum of the gradients across the filters.
|
||||
prop_im2col_col2im_additive = once $
|
||||
@ -129,7 +74,7 @@ prop_im2col_col2im_additive = once $
|
||||
[ 1.0, 2.0, 2.0, 1.0
|
||||
, 2.0, 4.0, 4.0, 2.0
|
||||
, 1.0, 2.0, 2.0, 1.0 ]
|
||||
out = col2im 2 2 1 1 3 4 . im2colUnsafe 2 2 1 1 $ input
|
||||
out = col2im 2 2 1 1 3 4 . im2col 2 2 1 1 $ input
|
||||
in expected === out
|
||||
|
||||
prop_simple_conv_forwards = once $
|
||||
@ -159,25 +104,18 @@ prop_simple_conv_forwards = once $
|
||||
[ 1.0, 2.0, 5.0
|
||||
, 3.0, 4.0, 6.0] :: HStatic.L 2 3)
|
||||
|
||||
expect = [ HStatic.matrix
|
||||
[ -3.0 , -4.0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ -1.0 , 1.0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ 5.0 , 9.0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ -7.0 , -10.0 ] :: HStatic.L 1 2] :: [HStatic.L 1 2]
|
||||
expect = HStatic.matrix
|
||||
[ -3.0 , -4.0
|
||||
, -1.0 , 1.0
|
||||
, 5.0 , 9.0
|
||||
, -7.0 , -10.0 ] :: HStatic.L 4 2
|
||||
out = runForwards convLayer input :: S' ('D3 1 2 4)
|
||||
|
||||
grad = S3D' ( mkVector
|
||||
[ HStatic.matrix
|
||||
[ 1 , 0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ 0 , 1 ] :: HStatic.L 1 2] ) :: S' ('D3 1 2 4)
|
||||
grad = S3D' ( HStatic.matrix
|
||||
[ 1 , 0
|
||||
, 0 , 0
|
||||
, 0 , 0
|
||||
, 0 , 1 ] :: HStatic.L 4 2 ) :: S' ('D3 1 2 4)
|
||||
|
||||
expectBack = (HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0
|
||||
@ -186,19 +124,19 @@ prop_simple_conv_forwards = once $
|
||||
|
||||
in case (out, inX, nc) of
|
||||
(S3D' out' , S2D' inX', Convolution' backGrad)
|
||||
-> ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
|
||||
-> (HStatic.extract expect === HStatic.extract out')
|
||||
.&&. (HStatic.extract expectBack === HStatic.extract inX')
|
||||
.&&. (HStatic.extract expectedGradient === HStatic.extract backGrad)
|
||||
|
||||
prop_vid2col_no_stride = once $
|
||||
let input = [(3><4)
|
||||
let input = (6><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
, (3><4)
|
||||
[ 21.0, 22.0, 23.0, 24.0
|
||||
, 9.0, 10.0, 11.0, 12.0
|
||||
-- -- --
|
||||
, 21.0, 22.0, 23.0, 24.0
|
||||
, 25.0, 26.0, 27.0, 28.0
|
||||
, 29.0, 30.0, 31.0, 32.0 ] ]
|
||||
, 29.0, 30.0, 31.0, 32.0 ]
|
||||
expected = (6><8)
|
||||
[ 1.0, 2.0, 5.0, 6.0 , 21.0, 22.0, 25.0, 26.0
|
||||
, 2.0, 3.0, 6.0, 7.0 , 22.0, 23.0, 26.0, 27.0
|
||||
@ -206,38 +144,36 @@ prop_vid2col_no_stride = once $
|
||||
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
|
||||
, 6.0, 7.0, 10.0, 11.0 , 26.0, 27.0, 30.0, 31.0
|
||||
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
|
||||
out = vid2colUnsafe 2 2 1 1 3 4 input
|
||||
out_c = vid2col 2 2 1 1 3 4 input
|
||||
in expected === out .&&. expected === out_c
|
||||
in expected === out_c
|
||||
|
||||
prop_vid2col_stride = once $
|
||||
let input = [(3><4)
|
||||
let input = (6><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
, (3><4)
|
||||
[ 21.0, 22.0, 23.0, 24.0
|
||||
, 9.0, 10.0, 11.0, 12.0
|
||||
-- -- -- -- --
|
||||
, 21.0, 22.0, 23.0, 24.0
|
||||
, 25.0, 26.0, 27.0, 28.0
|
||||
, 29.0, 30.0, 31.0, 32.0 ] ]
|
||||
, 29.0, 30.0, 31.0, 32.0 ]
|
||||
expected = (4><8)
|
||||
[ 1.0, 2.0, 5.0, 6.0 , 21.0, 22.0, 25.0, 26.0
|
||||
, 3.0, 4.0, 7.0, 8.0 , 23.0, 24.0, 27.0, 28.0
|
||||
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
|
||||
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
|
||||
out = vid2colUnsafe 2 2 1 2 3 4 input
|
||||
out_c = vid2col 2 2 1 2 3 4 input
|
||||
in expected === out .&&. expected === out_c
|
||||
in expected === out_c
|
||||
|
||||
prop_vid2col_invert = once $
|
||||
let input = [(3><4)
|
||||
let input = (6><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
, (3><4)
|
||||
[ 21.0, 22.0, 23.0, 24.0
|
||||
, 9.0, 10.0, 11.0, 12.0
|
||||
-- -- -- -- --
|
||||
, 21.0, 22.0, 23.0, 24.0
|
||||
, 25.0, 26.0, 27.0, 28.0
|
||||
, 29.0, 30.0, 31.0, 32.0 ] ]
|
||||
out = col2vid 3 2 3 2 3 4 . vid2colUnsafe 3 2 3 2 3 4 $ input
|
||||
, 29.0, 30.0, 31.0, 32.0 ]
|
||||
out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
|
||||
in input === out
|
||||
|
||||
|
||||
@ -266,29 +202,22 @@ prop_single_conv_forwards = once $
|
||||
|
||||
convLayer = Convolution myKernel zeroKernel :: Convolution 1 4 2 2 1 1
|
||||
|
||||
input = S3D' ( mkVector [HStatic.matrix
|
||||
input = S3D' ( HStatic.matrix
|
||||
[ 1.0, 2.0, 5.0
|
||||
, 3.0, 4.0, 6.0] :: HStatic.L 2 3] ) :: S' ('D3 2 3 1)
|
||||
, 3.0, 4.0, 6.0] :: HStatic.L 2 3 ) :: S' ('D3 2 3 1)
|
||||
|
||||
expect = [HStatic.matrix
|
||||
[ -3.0 , -4.0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ -1.0 , 1.0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ 5.0 , 9.0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ -7.0 , -10.0 ] :: HStatic.L 1 2] :: [HStatic.L 1 2]
|
||||
expect = HStatic.matrix
|
||||
[ -3.0 , -4.0
|
||||
, -1.0 , 1.0
|
||||
, 5.0 , 9.0
|
||||
, -7.0 , -10.0 ] :: HStatic.L 4 2
|
||||
out = runForwards convLayer input :: S' ('D3 1 2 4)
|
||||
|
||||
grad = S3D' ( mkVector
|
||||
[HStatic.matrix
|
||||
[ 1 , 0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ 0 , 1 ] :: HStatic.L 1 2] ) :: S' ('D3 1 2 4)
|
||||
grad = S3D' (HStatic.matrix
|
||||
[ 1 , 0
|
||||
, 0 , 0
|
||||
, 0 , 0
|
||||
, 0 , 1 ] :: HStatic.L 4 2 ) :: S' ('D3 1 2 4)
|
||||
|
||||
expectBack = (HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0
|
||||
@ -297,8 +226,8 @@ prop_single_conv_forwards = once $
|
||||
|
||||
in case (out, inX, nc) of
|
||||
(S3D' out' , S3D' inX', Convolution' backGrad)
|
||||
-> ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
|
||||
.&&. ([HStatic.extract expectBack] === (HStatic.extract <$> vecToList inX'))
|
||||
-> (HStatic.extract expect === HStatic.extract out')
|
||||
.&&. (HStatic.extract expectBack === HStatic.extract inX')
|
||||
.&&. (HStatic.extract expectedGradient === HStatic.extract backGrad)
|
||||
|
||||
return []
|
||||
|
@ -18,7 +18,7 @@ prop_pool = once $
|
||||
expected = (2><3)
|
||||
[ 14.0, 14.0, 8.0
|
||||
, 10.0, 11.0, 12.0 ]
|
||||
out = poolForward 2 2 1 1 2 3 input
|
||||
out = poolForward 1 3 4 2 2 1 1 input
|
||||
in expected === out
|
||||
|
||||
prop_pool_rectangular = once $
|
||||
@ -29,7 +29,23 @@ prop_pool_rectangular = once $
|
||||
expected = (2><2)
|
||||
[ 14.0, 14.0
|
||||
, 11.0, 12.0 ]
|
||||
out = poolForward 2 3 1 1 2 2 input
|
||||
out = poolForward 1 3 4 2 3 1 1 input
|
||||
in expected === out
|
||||
|
||||
prop_pool_channels = once $
|
||||
let input = (6><4)
|
||||
[ 1.0, 14.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0
|
||||
, 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
expected = (4><2)
|
||||
[ 14.0, 14.0
|
||||
, 11.0, 12.0
|
||||
, 7.0, 8.0
|
||||
, 11.0, 12.0 ]
|
||||
out = poolForward 2 3 4 2 3 1 1 input
|
||||
in expected === out
|
||||
|
||||
prop_pool_backwards = once $
|
||||
@ -44,7 +60,7 @@ prop_pool_backwards = once $
|
||||
[ 0.0, 0.0, 0.0, 0.0
|
||||
, 0.0, -6.0, -7.0, -8.0
|
||||
, 0.0,-10.0,-11.0,-12.0 ]
|
||||
out = poolBackward 2 2 1 1 input grads
|
||||
out = poolBackward 1 3 4 2 2 1 1 input grads
|
||||
in expected === out
|
||||
|
||||
prop_pool_backwards_additive = once $
|
||||
@ -59,7 +75,7 @@ prop_pool_backwards_additive = once $
|
||||
[-6.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 0.0,-18.0,-20.0
|
||||
,-10.0, 0.0, 0.0, 0.0 ]
|
||||
out = poolBackward 2 2 1 1 input grads
|
||||
out = poolBackward 1 3 4 2 2 1 1 input grads
|
||||
in expected === out
|
||||
|
||||
return []
|
||||
|
Loading…
Reference in New Issue
Block a user