mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-25 05:34:37 +03:00
Make things faster
This commit is contained in:
parent
d360438fc0
commit
b090b5f073
@ -40,6 +40,7 @@ library
|
|||||||
Grenade.Core.Shape
|
Grenade.Core.Shape
|
||||||
Grenade.Layers.Crop
|
Grenade.Layers.Crop
|
||||||
Grenade.Layers.Convolution
|
Grenade.Layers.Convolution
|
||||||
|
Grenade.Layers.Convolution.Internal
|
||||||
Grenade.Layers.Dropout
|
Grenade.Layers.Dropout
|
||||||
Grenade.Layers.FullyConnected
|
Grenade.Layers.FullyConnected
|
||||||
Grenade.Layers.Flatten
|
Grenade.Layers.Flatten
|
||||||
|
@ -51,7 +51,7 @@ class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
|
|||||||
-- layer gave from the input and the back propagated derivatives from
|
-- layer gave from the input and the back propagated derivatives from
|
||||||
-- the layer above.
|
-- the layer above.
|
||||||
-- Returns the gradient layer and the derivatives to push back further.
|
-- Returns the gradient layer and the derivatives to push back further.
|
||||||
runBackards :: x -> S' i -> S' o -> (Gradient x, S' i)
|
runBackwards :: x -> S' i -> S' o -> (Gradient x, S' i)
|
||||||
|
|
||||||
-- | Type of a network.
|
-- | Type of a network.
|
||||||
-- The [*] type specifies the types of the layers. This is needed for parallel
|
-- The [*] type specifies the types of the layers. This is needed for parallel
|
||||||
|
@ -32,7 +32,7 @@ backPropagate network input target =
|
|||||||
-- recursively run the rest of the network, and get the gradients from above.
|
-- recursively run the rest of the network, and get the gradients from above.
|
||||||
(n', dWs') = go y n
|
(n', dWs') = go y n
|
||||||
-- calculate the gradient for this layer to pass down,
|
-- calculate the gradient for this layer to pass down,
|
||||||
(layer', dWs) = runBackards layer x dWs'
|
(layer', dWs) = runBackwards layer x dWs'
|
||||||
|
|
||||||
in (layer' :/> n', dWs)
|
in (layer' :/> n', dWs)
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ backPropagate network input target =
|
|||||||
go !x (O layer)
|
go !x (O layer)
|
||||||
= let y = runForwards layer x
|
= let y = runForwards layer x
|
||||||
-- the gradient (how much y affects the error)
|
-- the gradient (how much y affects the error)
|
||||||
(layer', dWs) = runBackards layer x (y - target)
|
(layer', dWs) = runBackwards layer x (y - target)
|
||||||
|
|
||||||
in (OG layer', dWs)
|
in (OG layer', dWs)
|
||||||
|
|
||||||
|
@ -16,11 +16,6 @@ module Grenade.Layers.Convolution (
|
|||||||
Convolution (..)
|
Convolution (..)
|
||||||
, Convolution' (..)
|
, Convolution' (..)
|
||||||
, randomConvolution
|
, randomConvolution
|
||||||
, im2col
|
|
||||||
, vid2col
|
|
||||||
, col2im
|
|
||||||
, col2vid
|
|
||||||
, fittingStarts
|
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Control.Monad.Random hiding ( fromList )
|
import Control.Monad.Random hiding ( fromList )
|
||||||
@ -36,6 +31,7 @@ import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)
|
|||||||
import Grenade.Core.Network
|
import Grenade.Core.Network
|
||||||
import Grenade.Core.Shape
|
import Grenade.Core.Shape
|
||||||
import Grenade.Core.Vector
|
import Grenade.Core.Vector
|
||||||
|
import Grenade.Layers.Convolution.Internal
|
||||||
|
|
||||||
-- | A convolution layer for a neural network.
|
-- | A convolution layer for a neural network.
|
||||||
-- This uses the im2col convolution trick popularised by Caffe, which essentially turns the
|
-- This uses the im2col convolution trick popularised by Caffe, which essentially turns the
|
||||||
@ -153,18 +149,21 @@ instance ( KnownNat kernelRows
|
|||||||
runForwards (Convolution kernel _) (S2D' input) =
|
runForwards (Convolution kernel _) (S2D' input) =
|
||||||
let ex = extract input
|
let ex = extract input
|
||||||
ek = extract kernel
|
ek = extract kernel
|
||||||
|
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||||
|
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
|
||||||
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||||
ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
|
ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
|
||||||
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
||||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
||||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
||||||
c = im2col kx ky sx sy ex
|
c = im2colUnsafe kx ky sx sy ix iy ex
|
||||||
mt = c LA.<> ek
|
mt = c LA.<> ek
|
||||||
r = col2vid 1 1 1 1 ox oy mt
|
r = col2vidUnsafe 1 1 1 1 ox oy mt
|
||||||
rs = fmap (fromJust . create) r
|
rs = fmap (fromJust . create) r
|
||||||
in S3D' $ mkVector rs
|
in S3D' $ mkVector rs
|
||||||
runBackards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
|
|
||||||
|
runBackwards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
|
||||||
let ex = extract input
|
let ex = extract input
|
||||||
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||||
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
|
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
|
||||||
@ -174,17 +173,19 @@ instance ( KnownNat kernelRows
|
|||||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
||||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
||||||
c = im2col kx ky sx sy ex
|
fl = fromIntegral $ natVal (Proxy :: Proxy filters)
|
||||||
|
|
||||||
|
c = im2colUnsafe kx ky sx sy ix iy ex
|
||||||
|
|
||||||
eo = vecToList $ fmap extract dEdy
|
eo = vecToList $ fmap extract dEdy
|
||||||
ek = extract kernel
|
ek = extract kernel
|
||||||
|
|
||||||
vs = vid2col 1 1 1 1 ox oy eo
|
vs = vid2colUnsafe fl 1 1 1 1 ox oy eo
|
||||||
|
|
||||||
kN = fromJust . create $ tr c LA.<> vs
|
kN = fromJust . create $ tr c LA.<> vs
|
||||||
dW = vs LA.<> tr ek
|
dW = vs LA.<> tr ek
|
||||||
|
|
||||||
xW = col2im kx ky sx sy ix iy dW
|
xW = col2imUnsafe kx ky sx sy ix iy dW
|
||||||
in (Convolution' kN, S2D' . fromJust . create $ xW)
|
in (Convolution' kN, S2D' . fromJust . create $ xW)
|
||||||
|
|
||||||
|
|
||||||
@ -215,12 +216,13 @@ instance ( KnownNat kernelRows
|
|||||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
||||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
||||||
c = vid2col kx ky sx sy ix iy ex
|
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
|
||||||
|
c = vid2colUnsafe ch kx ky sx sy ix iy ex
|
||||||
mt = c LA.<> ek
|
mt = c LA.<> ek
|
||||||
r = col2vid 1 1 1 1 ox oy mt
|
r = col2vidUnsafe 1 1 1 1 ox oy mt
|
||||||
rs = fmap (fromJust . create) r
|
rs = fmap (fromJust . create) r
|
||||||
in S3D' $ mkVector rs
|
in S3D' $ mkVector rs
|
||||||
runBackards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
|
runBackwards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
|
||||||
let ex = vecToList $ fmap extract input
|
let ex = vecToList $ fmap extract input
|
||||||
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||||
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
|
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
|
||||||
@ -230,77 +232,18 @@ instance ( KnownNat kernelRows
|
|||||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
||||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
||||||
c = vid2col kx ky sx sy ix iy ex
|
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
|
||||||
|
fl = fromIntegral $ natVal (Proxy :: Proxy filters)
|
||||||
|
c = vid2colUnsafe ch kx ky sx sy ix iy ex
|
||||||
|
|
||||||
eo = vecToList $ fmap extract dEdy
|
eo = vecToList $ fmap extract dEdy
|
||||||
ek = extract kernel
|
ek = extract kernel
|
||||||
|
|
||||||
vs = vid2col 1 1 1 1 ox oy eo
|
vs = vid2colUnsafe fl 1 1 1 1 ox oy eo
|
||||||
|
|
||||||
kN = fromJust . create $ tr c LA.<> vs
|
kN = fromJust . create $ tr c LA.<> vs
|
||||||
|
|
||||||
dW = vs LA.<> tr ek
|
dW = vs LA.<> tr ek
|
||||||
|
|
||||||
xW = col2vid kx ky sx sy ix iy dW
|
xW = col2vidUnsafe kx ky sx sy ix iy dW
|
||||||
in (Convolution' kN, S3D' . mkVector . fmap (fromJust . create) $ xW)
|
in (Convolution' kN, S3D' . mkVector . fmap (fromJust . create) $ xW)
|
||||||
|
|
||||||
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
|
||||||
im2col nrows ncols srows scols m =
|
|
||||||
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
|
|
||||||
in im2colFit starts nrows ncols m
|
|
||||||
|
|
||||||
im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
|
|
||||||
im2colFit starts nrows ncols m =
|
|
||||||
let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts
|
|
||||||
in fromRows imRows
|
|
||||||
|
|
||||||
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
|
|
||||||
vid2col nrows ncols srows scols inputrows inputcols ms =
|
|
||||||
let starts = fittingStarts inputrows nrows srows inputcols ncols scols
|
|
||||||
subs = fmap (im2colFit starts nrows ncols) ms
|
|
||||||
in foldl1 (|||) subs
|
|
||||||
|
|
||||||
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
|
|
||||||
col2vid nrows ncols srows scols drows dcols m =
|
|
||||||
let starts = fittingStart (cols m) (nrows * ncols) (nrows * ncols)
|
|
||||||
r = rows m
|
|
||||||
mats = fmap (\s -> subMatrix (0,s) (r, nrows * ncols) m) starts
|
|
||||||
colSts = fittingStarts drows nrows srows dcols ncols scols
|
|
||||||
in fmap (col2imfit colSts nrows ncols drows dcols) mats
|
|
||||||
|
|
||||||
col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
|
||||||
col2im krows kcols srows scols drows dcols m =
|
|
||||||
let starts = fittingStarts drows krows srows dcols kcols scols
|
|
||||||
in col2imfit starts krows kcols drows dcols m
|
|
||||||
|
|
||||||
col2imfit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
|
||||||
col2imfit starts krows kcols drows dcols m =
|
|
||||||
let indicies = fmap (\[a,b] -> (a,b)) $ sequence [[0..(krows-1)], [0..(kcols-1)]]
|
|
||||||
convs = fmap (zip indicies . toList) . toRows $ m
|
|
||||||
pairs = zip convs starts
|
|
||||||
accums = concat $ fmap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
|
|
||||||
in accum (LA.konst 0 (drows, dcols)) (+) accums
|
|
||||||
|
|
||||||
|
|
||||||
-- | These functions are not even remotely safe, but it's only called from the statically typed
|
|
||||||
-- commands, so we should be good ?!?!?
|
|
||||||
-- Returns the starting sub matrix locations which fit inside the larger matrix for the
|
|
||||||
-- convolution. Takes into account the stride and kernel size.
|
|
||||||
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
|
|
||||||
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
|
|
||||||
let rs = fittingStart nrows kernelrows steprows
|
|
||||||
cs = fittingStart ncols kernelcols stepcolsh
|
|
||||||
ls = sequence [rs, cs]
|
|
||||||
in fmap (\[a,b] -> (a,b)) ls
|
|
||||||
|
|
||||||
-- | Returns the starting sub vector which fit inside the larger vector for the
|
|
||||||
-- convolution. Takes into account the stride and kernel size.
|
|
||||||
fittingStart :: Int -> Int -> Int -> [Int]
|
|
||||||
fittingStart width kernel steps =
|
|
||||||
let go left | left + kernel < width
|
|
||||||
= left : go (left + steps)
|
|
||||||
| left + kernel == width
|
|
||||||
= left : []
|
|
||||||
| otherwise
|
|
||||||
= error "Kernel and step do not fit in matrix."
|
|
||||||
in go 0
|
|
||||||
|
238
src/Grenade/Layers/Convolution/Internal.hs
Normal file
238
src/Grenade/Layers/Convolution/Internal.hs
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
module Grenade.Layers.Convolution.Internal (
|
||||||
|
im2col
|
||||||
|
-- , im2colUnsafe
|
||||||
|
, vid2col
|
||||||
|
, col2im
|
||||||
|
, col2imFit
|
||||||
|
, col2vid
|
||||||
|
|
||||||
|
, col2vidUnsafe
|
||||||
|
, col2imUnsafe
|
||||||
|
, im2colUnsafe
|
||||||
|
, vid2colUnsafe
|
||||||
|
, fittingStarts
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Control.Monad.ST
|
||||||
|
import Control.Parallel.Strategies ( parMap, rseq )
|
||||||
|
|
||||||
|
import Data.STRef
|
||||||
|
import Data.Foldable ( forM_ )
|
||||||
|
|
||||||
|
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||||
|
import qualified Numeric.LinearAlgebra as LA
|
||||||
|
import qualified Numeric.LinearAlgebra.Devel as U
|
||||||
|
|
||||||
|
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
|
im2col nrows ncols srows scols m =
|
||||||
|
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
|
||||||
|
in im2colFit starts nrows ncols m
|
||||||
|
|
||||||
|
im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
|
im2colFit starts nrows ncols m =
|
||||||
|
let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts
|
||||||
|
in fromRows imRows
|
||||||
|
|
||||||
|
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
|
||||||
|
vid2col nrows ncols srows scols inputrows inputcols ms =
|
||||||
|
let starts = fittingStarts inputrows nrows srows inputcols ncols scols
|
||||||
|
subs = parMap rseq (im2colFit starts nrows ncols) ms
|
||||||
|
in foldl1 (|||) subs
|
||||||
|
|
||||||
|
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
|
||||||
|
col2vid krows kcols srows scols drows dcols m =
|
||||||
|
let starts = fittingStart (cols m) (krows * kcols) (krows * kcols)
|
||||||
|
r = rows m
|
||||||
|
mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts
|
||||||
|
colSts = fittingStarts drows krows srows dcols kcols scols
|
||||||
|
in parMap rseq (col2imFit colSts krows kcols drows dcols) mats
|
||||||
|
|
||||||
|
col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
|
col2im krows kcols srows scols drows dcols m =
|
||||||
|
let starts = fittingStarts drows krows srows dcols kcols scols
|
||||||
|
in col2imFit starts krows kcols drows dcols m
|
||||||
|
|
||||||
|
-- | These functions are not even remotely safe, but it's only called from the statically typed
|
||||||
|
-- commands, so we should be good ?!?!?
|
||||||
|
-- Returns the starting sub matrix locations which fit inside the larger matrix for the
|
||||||
|
-- convolution. Takes into account the stride and kernel size.
|
||||||
|
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
|
||||||
|
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
|
||||||
|
let rs = fittingStart nrows kernelrows steprows
|
||||||
|
cs = fittingStart ncols kernelcols stepcolsh
|
||||||
|
ls = sequence [rs, cs]
|
||||||
|
in fmap (\[a,b] -> (a,b)) ls
|
||||||
|
|
||||||
|
-- | Returns the starting sub vector which fit inside the larger vector for the
|
||||||
|
-- convolution. Takes into account the stride and kernel size.
|
||||||
|
fittingStart :: Int -> Int -> Int -> [Int]
|
||||||
|
fittingStart width kernel steps =
|
||||||
|
let go left | left + kernel < width
|
||||||
|
= left : go (left + steps)
|
||||||
|
| left + kernel == width
|
||||||
|
= [left]
|
||||||
|
| otherwise
|
||||||
|
= error "Kernel and step do not fit in matrix."
|
||||||
|
in go 0
|
||||||
|
|
||||||
|
col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
|
col2imFit starts krows kcols drows dcols m =
|
||||||
|
let indicies = (\[a,b] -> (a,b)) <$> sequence [[0..(krows-1)], [0..(kcols-1)]]
|
||||||
|
convs = fmap (zip indicies . toList) . toRows $ m
|
||||||
|
pairs = zip convs starts
|
||||||
|
accums = concatMap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
|
||||||
|
in accum (LA.konst 0 (drows, dcols)) (+) accums
|
||||||
|
|
||||||
|
-- void col2im_cpu(const Dtype* data_col, const int channels,
|
||||||
|
-- const int height, const int width, const int kernel_h, const int kernel_w,
|
||||||
|
-- const int pad_h, const int pad_w,
|
||||||
|
-- const int stride_h, const int stride_w,
|
||||||
|
-- const int dilation_h, const int dilation_w,
|
||||||
|
-- Dtype* data_im) {
|
||||||
|
-- caffe_set(height * width * channels, Dtype(0), data_im);
|
||||||
|
-- const int output_h = (height + 2 * pad_h -
|
||||||
|
-- (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
||||||
|
-- const int output_w = (width + 2 * pad_w -
|
||||||
|
-- (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
||||||
|
-- const int channel_size = height * width;
|
||||||
|
-- for (int channel = channels; channel--; data_im += channel_size) {
|
||||||
|
-- for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
|
||||||
|
-- for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
|
||||||
|
-- int input_row = -pad_h + kernel_row * dilation_h;
|
||||||
|
-- for (int output_rows = output_h; output_rows; output_rows--) {
|
||||||
|
-- if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
|
||||||
|
-- data_col += output_w;
|
||||||
|
-- } else {
|
||||||
|
-- int input_col = -pad_w + kernel_col * dilation_w;
|
||||||
|
-- for (int output_col = output_w; output_col; output_col--) {
|
||||||
|
-- if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
|
||||||
|
-- data_im[input_row * width + input_col] += *data_col;
|
||||||
|
-- }
|
||||||
|
-- data_col++;
|
||||||
|
-- input_col += stride_w;
|
||||||
|
-- }
|
||||||
|
-- }
|
||||||
|
-- input_row += stride_h;
|
||||||
|
-- }
|
||||||
|
-- }
|
||||||
|
-- }
|
||||||
|
-- }
|
||||||
|
-- }
|
||||||
|
|
||||||
|
|
||||||
|
-- let starts = fittingStart (cols m) (krows * kcols) (krows * kcols)
|
||||||
|
-- r = rows m
|
||||||
|
-- mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts
|
||||||
|
-- in parMap rseq (col2imUnsafe krows kcols srows scols drows dcols) mats
|
||||||
|
|
||||||
|
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
|
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
|
||||||
|
let columnMatrixRows = rows columnMatrix
|
||||||
|
|
||||||
|
dataIm <- U.newMatrix 0 destinationRows destinationCols
|
||||||
|
|
||||||
|
offsetR <- newSTRef 0
|
||||||
|
offsetC <- newSTRef 0
|
||||||
|
|
||||||
|
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
|
||||||
|
inputColumn <- newSTRef 0
|
||||||
|
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||||
|
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||||
|
ic <- readSTRef inputColumn
|
||||||
|
offsetR' <- readSTRef offsetR
|
||||||
|
offsetC' <- readSTRef offsetC
|
||||||
|
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ atIndex columnMatrix (ir,ic))
|
||||||
|
modifySTRef inputColumn (+1)
|
||||||
|
|
||||||
|
offsetC' <- readSTRef offsetC
|
||||||
|
if offsetC' + kernelColumns < destinationCols
|
||||||
|
then modifySTRef offsetC (+ strideColumns)
|
||||||
|
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
|
||||||
|
|
||||||
|
return dataIm
|
||||||
|
|
||||||
|
col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
|
||||||
|
col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do
|
||||||
|
let columnMatrixRows = rows columnMatrix
|
||||||
|
let filters = cols columnMatrix `div` (kernelRows * kernelColumns)
|
||||||
|
|
||||||
|
dataIms <- traverse (\_ -> U.newMatrix 0 destinationRows destinationCols) [0 .. filters-1]
|
||||||
|
|
||||||
|
offsetR <- newSTRef 0
|
||||||
|
offsetC <- newSTRef 0
|
||||||
|
offsetM <- newSTRef 0
|
||||||
|
|
||||||
|
forM_ dataIms $ \dataIm -> do
|
||||||
|
offsetM' <- readSTRef offsetM
|
||||||
|
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
|
||||||
|
inputColumn <- newSTRef 0
|
||||||
|
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||||
|
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||||
|
ic <- readSTRef inputColumn
|
||||||
|
offsetR' <- readSTRef offsetR
|
||||||
|
offsetC' <- readSTRef offsetC
|
||||||
|
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ atIndex columnMatrix (ir, ic + offsetM'))
|
||||||
|
modifySTRef inputColumn (+1)
|
||||||
|
|
||||||
|
offsetC' <- readSTRef offsetC
|
||||||
|
if offsetC' + kernelColumns < destinationCols
|
||||||
|
then modifySTRef offsetC (+ strideColumns)
|
||||||
|
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
|
||||||
|
|
||||||
|
writeSTRef offsetR 0
|
||||||
|
writeSTRef offsetC 0
|
||||||
|
modifySTRef offsetM (+ (kernelRows * kernelColumns))
|
||||||
|
|
||||||
|
traverse U.freezeMatrix dataIms
|
||||||
|
|
||||||
|
vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
|
||||||
|
vid2colUnsafe channels kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do
|
||||||
|
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
|
||||||
|
matWidth = kernelRows * kernelColumns
|
||||||
|
destinationRows = 1 + (vidrows - kernelRows) `div` striderows
|
||||||
|
destinationCols = 1 + (vidcols - kernelColumns) `div` stridecols
|
||||||
|
destinationSize = destinationRows * destinationCols
|
||||||
|
|
||||||
|
dataCol <- U.newMatrix 0 destinationSize (channels * matWidth)
|
||||||
|
|
||||||
|
offsetC <- newSTRef 0
|
||||||
|
|
||||||
|
forM_ dataVid $ \dataIm -> do
|
||||||
|
inputRow <- newSTRef 0
|
||||||
|
offsetC' <- readSTRef offsetC
|
||||||
|
forM_ starts $ \(startRow, startCol) -> do
|
||||||
|
inputColumn <- newSTRef 0
|
||||||
|
inputRow' <- readSTRef inputRow
|
||||||
|
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||||
|
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||||
|
inputColumn' <- readSTRef inputColumn
|
||||||
|
U.modifyMatrix dataCol inputRow' (inputColumn' + offsetC') (+ atIndex dataIm (kr + startRow, kc + startCol))
|
||||||
|
modifySTRef inputColumn (+1)
|
||||||
|
modifySTRef inputRow (+1)
|
||||||
|
|
||||||
|
modifySTRef offsetC (+ matWidth)
|
||||||
|
|
||||||
|
return dataCol
|
||||||
|
|
||||||
|
im2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
|
im2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataIm = U.runSTMatrix $ do
|
||||||
|
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
|
||||||
|
matWidth = kernelRows * kernelColumns
|
||||||
|
destinationRows = 1 + (vidrows - kernelRows) `div` striderows
|
||||||
|
destinationCols = 1 + (vidcols - kernelColumns) `div` stridecols
|
||||||
|
destinationSize = destinationRows * destinationCols
|
||||||
|
|
||||||
|
dataCol <- U.newMatrix 0 destinationSize matWidth
|
||||||
|
|
||||||
|
inputRow <- newSTRef 0
|
||||||
|
forM_ starts $ \(startRow, startCol) -> do
|
||||||
|
inputColumn <- newSTRef 0
|
||||||
|
inputRow' <- readSTRef inputRow
|
||||||
|
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||||
|
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||||
|
inputColumn' <- readSTRef inputColumn
|
||||||
|
U.modifyMatrix dataCol inputRow' inputColumn' (+ atIndex dataIm (kr + startRow, kc + startCol))
|
||||||
|
modifySTRef inputColumn (+1)
|
||||||
|
modifySTRef inputRow (+1)
|
||||||
|
|
||||||
|
return dataCol
|
@ -58,7 +58,7 @@ instance ( KnownNat cropLeft
|
|||||||
m = extract input
|
m = extract input
|
||||||
r = subMatrix (cropt, cropl) (nrows, ncols) m
|
r = subMatrix (cropt, cropl) (nrows, ncols) m
|
||||||
in S2D' . fromJust . create $ r
|
in S2D' . fromJust . create $ r
|
||||||
runBackards _ _ (S2D' dEdy) =
|
runBackwards _ _ (S2D' dEdy) =
|
||||||
let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
|
let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
|
||||||
cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
|
cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
|
||||||
cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight)
|
cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight)
|
||||||
|
@ -47,5 +47,5 @@ randomDropout rate = do
|
|||||||
instance (KnownNat i) => Layer (Dropout i) ('D1 i) ('D1 i) where
|
instance (KnownNat i) => Layer (Dropout i) ('D1 i) ('D1 i) where
|
||||||
runForwards (Dropout drops) (S1D' x) = S1D' $ x * drops
|
runForwards (Dropout drops) (S1D' x) = S1D' $ x * drops
|
||||||
runForwards (Pass rate) (S1D' x)= S1D' $ dvmap (* (1 - rate)) x
|
runForwards (Pass rate) (S1D' x)= S1D' $ dvmap (* (1 - rate)) x
|
||||||
runBackards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops)
|
runBackwards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops)
|
||||||
runBackards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x)
|
runBackwards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x)
|
||||||
|
@ -33,11 +33,11 @@ instance UpdateLayer FlattenLayer where
|
|||||||
|
|
||||||
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
|
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
|
||||||
runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
|
runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
|
||||||
runBackards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
|
runBackwards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
|
||||||
|
|
||||||
instance (KnownNat a, KnownNat x, KnownNat y, KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
|
instance (KnownNat a, KnownNat x, KnownNat y, KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
|
||||||
runForwards _ (S3D' y) = S1D' . raiseShapeError . create . vjoin . vecToList . fmap (flatten . extract) $ y
|
runForwards _ (S3D' y) = S1D' . raiseShapeError . create . vjoin . vecToList . fmap (flatten . extract) $ y
|
||||||
runBackards _ _ (S1D' o) =
|
runBackwards _ _ (S1D' o) =
|
||||||
let x' = fromIntegral $ natVal (Proxy :: Proxy x)
|
let x' = fromIntegral $ natVal (Proxy :: Proxy x)
|
||||||
y' = fromIntegral $ natVal (Proxy :: Proxy y)
|
y' = fromIntegral $ natVal (Proxy :: Proxy y)
|
||||||
z' = fromIntegral $ natVal (Proxy :: Proxy z)
|
z' = fromIntegral $ natVal (Proxy :: Proxy z)
|
||||||
|
@ -52,7 +52,7 @@ instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o)
|
|||||||
runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)
|
runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)
|
||||||
|
|
||||||
-- Run a backpropogation step for a full connected layer.
|
-- Run a backpropogation step for a full connected layer.
|
||||||
runBackards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
|
runBackwards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
|
||||||
let wB' = dEdy
|
let wB' = dEdy
|
||||||
mm' = dEdy `outer` x
|
mm' = dEdy `outer` x
|
||||||
-- calcluate derivatives for next step
|
-- calcluate derivatives for next step
|
||||||
|
@ -45,8 +45,8 @@ instance (Layer x i h, Layer y h o) => Layer (Fuse x y i h o) i o where
|
|||||||
let yInput :: S' h = runForwards x input
|
let yInput :: S' h = runForwards x input
|
||||||
in runForwards y yInput
|
in runForwards y yInput
|
||||||
|
|
||||||
runBackards (x :$$ y) input backGradient =
|
runBackwards (x :$$ y) input backGradient =
|
||||||
let yInput :: S' h = runForwards x input
|
let yInput :: S' h = runForwards x input
|
||||||
(y', yGrad) = runBackards y yInput backGradient
|
(y', yGrad) = runBackwards y yInput backGradient
|
||||||
(x', xGrad) = runBackards x input yGrad
|
(x', xGrad) = runBackwards x input yGrad
|
||||||
in ((x', y'), xGrad)
|
in ((x', y'), xGrad)
|
||||||
|
@ -29,15 +29,15 @@ instance UpdateLayer Logit where
|
|||||||
|
|
||||||
instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where
|
instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where
|
||||||
runForwards _ (S1D' y) = S1D' (logistic y)
|
runForwards _ (S1D' y) = S1D' (logistic y)
|
||||||
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy))
|
runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy))
|
||||||
|
|
||||||
instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where
|
instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where
|
||||||
runForwards _ (S2D' y) = S2D' (logistic y)
|
runForwards _ (S2D' y) = S2D' (logistic y)
|
||||||
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))
|
runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))
|
||||||
|
|
||||||
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where
|
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where
|
||||||
runForwards _ (S3D' y) = S3D' (fmap logistic y)
|
runForwards _ (S3D' y) = S3D' (fmap logistic y)
|
||||||
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> logistic' y' * dEdy') y dEdy))
|
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> logistic' y' * dEdy') y dEdy))
|
||||||
|
|
||||||
|
|
||||||
logistic :: Floating a => a -> a
|
logistic :: Floating a => a -> a
|
||||||
|
@ -58,7 +58,7 @@ instance ( KnownNat padLeft
|
|||||||
m = extract input
|
m = extract input
|
||||||
r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
|
r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
|
||||||
in S2D' . fromJust . create $ r
|
in S2D' . fromJust . create $ r
|
||||||
runBackards Pad _ (S2D' dEdy) =
|
runBackwards Pad _ (S2D' dEdy) =
|
||||||
let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
|
let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
|
||||||
padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
|
padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
|
||||||
nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||||
|
@ -24,7 +24,7 @@ import GHC.TypeLits
|
|||||||
import Grenade.Core.Network
|
import Grenade.Core.Network
|
||||||
import Grenade.Core.Shape
|
import Grenade.Core.Shape
|
||||||
import Grenade.Core.Vector
|
import Grenade.Core.Vector
|
||||||
import Grenade.Layers.Convolution
|
import Grenade.Layers.Convolution.Internal
|
||||||
|
|
||||||
import Numeric.LinearAlgebra hiding (uniformSample)
|
import Numeric.LinearAlgebra hiding (uniformSample)
|
||||||
import qualified Numeric.LinearAlgebra as LA
|
import qualified Numeric.LinearAlgebra as LA
|
||||||
@ -75,7 +75,7 @@ instance ( KnownNat kernelRows
|
|||||||
r = poolForward kx ky sx sy ox oy $ ex
|
r = poolForward kx ky sx sy ox oy $ ex
|
||||||
rs = fromJust . create $ r
|
rs = fromJust . create $ r
|
||||||
in S2D' $ rs
|
in S2D' $ rs
|
||||||
runBackards Pooling (S2D' input) (S2D' dEdy) =
|
runBackwards Pooling (S2D' input) (S2D' dEdy) =
|
||||||
let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||||
ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
|
ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
|
||||||
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
||||||
@ -111,7 +111,7 @@ instance ( KnownNat kernelRows
|
|||||||
r = poolForwardList kx ky sx sy ix iy ox oy ex
|
r = poolForwardList kx ky sx sy ix iy ox oy ex
|
||||||
rs = fmap (fromJust . create) r
|
rs = fmap (fromJust . create) r
|
||||||
in S3D' rs
|
in S3D' rs
|
||||||
runBackards Pooling (S3D' input) (S3D' dEdy) =
|
runBackwards Pooling (S3D' input) (S3D' dEdy) =
|
||||||
let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||||
iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
|
iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
|
||||||
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||||
|
@ -31,7 +31,7 @@ instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where
|
|||||||
runForwards _ (S1D' y) = S1D' (relu y)
|
runForwards _ (S1D' y) = S1D' (relu y)
|
||||||
where
|
where
|
||||||
relu = LAS.dvmap (\a -> if a <= 0 then 0 else a)
|
relu = LAS.dvmap (\a -> if a <= 0 then 0 else a)
|
||||||
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy))
|
runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy))
|
||||||
where
|
where
|
||||||
relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1)
|
relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1)
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where
|
|||||||
runForwards _ (S2D' y) = S2D' (relu y)
|
runForwards _ (S2D' y) = S2D' (relu y)
|
||||||
where
|
where
|
||||||
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
|
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
|
||||||
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy))
|
runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy))
|
||||||
where
|
where
|
||||||
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
|
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
|
||||||
|
|
||||||
@ -47,6 +47,6 @@ instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j
|
|||||||
runForwards _ (S3D' y) = S3D' (fmap relu y)
|
runForwards _ (S3D' y) = S3D' (fmap relu y)
|
||||||
where
|
where
|
||||||
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
|
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
|
||||||
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> relu' y' * dEdy') y dEdy))
|
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> relu' y' * dEdy') y dEdy))
|
||||||
where
|
where
|
||||||
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
|
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
|
||||||
|
@ -26,15 +26,15 @@ instance UpdateLayer Tanh where
|
|||||||
|
|
||||||
instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
|
instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
|
||||||
runForwards _ (S1D' y) = S1D' (tanh y)
|
runForwards _ (S1D' y) = S1D' (tanh y)
|
||||||
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy))
|
runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy))
|
||||||
|
|
||||||
instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where
|
instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where
|
||||||
runForwards _ (S2D' y) = S2D' (tanh y)
|
runForwards _ (S2D' y) = S2D' (tanh y)
|
||||||
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))
|
runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))
|
||||||
|
|
||||||
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where
|
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where
|
||||||
runForwards _ (S3D' y) = S3D' (fmap tanh y)
|
runForwards _ (S3D' y) = S3D' (fmap tanh y)
|
||||||
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> tanh' y' * dEdy') y dEdy))
|
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> tanh' y' * dEdy') y dEdy))
|
||||||
|
|
||||||
tanh' :: (Floating a) => a -> a
|
tanh' :: (Floating a) => a -> a
|
||||||
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
|
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
|
||||||
|
@ -8,6 +8,7 @@ import Grenade.Core.Shape
|
|||||||
import Grenade.Core.Vector as Grenade
|
import Grenade.Core.Vector as Grenade
|
||||||
import Grenade.Core.Network
|
import Grenade.Core.Network
|
||||||
import Grenade.Layers.Convolution
|
import Grenade.Layers.Convolution
|
||||||
|
import Grenade.Layers.Convolution.Internal
|
||||||
|
|
||||||
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
|
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
|
||||||
import qualified Numeric.LinearAlgebra.Static as HStatic
|
import qualified Numeric.LinearAlgebra.Static as HStatic
|
||||||
@ -63,6 +64,17 @@ prop_im2col_sym_on_same_stride = once $
|
|||||||
out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input
|
out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input
|
||||||
in input === out
|
in input === out
|
||||||
|
|
||||||
|
-- If there's no overlap (stride is the same size as the kernel)
|
||||||
|
-- then col2im . im2col should be symmetric.
|
||||||
|
prop_im2colunsafe_sym_on_same_stride = once $
|
||||||
|
let input = (3><4)
|
||||||
|
[ 1.0, 2.0, 3.0, 4.0
|
||||||
|
, 5.0, 6.0, 7.0, 8.0
|
||||||
|
, 9.0, 10.0, 11.0, 12.0 ]
|
||||||
|
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 3 4 $ input
|
||||||
|
in input === out
|
||||||
|
|
||||||
|
|
||||||
-- If there is an overlap, then the gradient passed back should be
|
-- If there is an overlap, then the gradient passed back should be
|
||||||
-- the sum of the gradients across the filters.
|
-- the sum of the gradients across the filters.
|
||||||
prop_im2col_col2im_additive = once $
|
prop_im2col_col2im_additive = once $
|
||||||
@ -127,7 +139,7 @@ prop_simple_conv_forwards = once $
|
|||||||
expectBack = (HStatic.matrix
|
expectBack = (HStatic.matrix
|
||||||
[ 1.0, 0.0, 0.0
|
[ 1.0, 0.0, 0.0
|
||||||
, 0.0, -2.0,-1.0] :: HStatic.L 2 3)
|
, 0.0, -2.0,-1.0] :: HStatic.L 2 3)
|
||||||
(nc, inX) = runBackards convLayer input grad
|
(nc, inX) = runBackwards convLayer input grad
|
||||||
|
|
||||||
in case (out, inX, nc) of
|
in case (out, inX, nc) of
|
||||||
(S3D' out' , S2D' inX', Convolution' backGrad)
|
(S3D' out' , S2D' inX', Convolution' backGrad)
|
||||||
@ -187,6 +199,19 @@ prop_vid2col_invert = once $
|
|||||||
out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
|
out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
|
||||||
in input === out
|
in input === out
|
||||||
|
|
||||||
|
prop_vid2col_invert_unsafe = once $
|
||||||
|
let input = [(3><4)
|
||||||
|
[ 1.0, 2.0, 3.0, 4.0
|
||||||
|
, 5.0, 6.0, 7.0, 8.0
|
||||||
|
, 9.0, 10.0, 11.0, 12.0 ]
|
||||||
|
, (3><4)
|
||||||
|
[ 21.0, 22.0, 23.0, 24.0
|
||||||
|
, 25.0, 26.0, 27.0, 28.0
|
||||||
|
, 29.0, 30.0, 31.0, 32.0 ] ]
|
||||||
|
out = col2vidUnsafe 3 2 3 2 3 4 . vid2colUnsafe 2 3 2 3 2 3 4 $ input
|
||||||
|
in input === out
|
||||||
|
|
||||||
|
|
||||||
-- This test show that 2D convs act the same
|
-- This test show that 2D convs act the same
|
||||||
-- 3D convs with one layer
|
-- 3D convs with one layer
|
||||||
prop_single_conv_forwards = once $
|
prop_single_conv_forwards = once $
|
||||||
@ -239,7 +264,7 @@ prop_single_conv_forwards = once $
|
|||||||
expectBack = (HStatic.matrix
|
expectBack = (HStatic.matrix
|
||||||
[ 1.0, 0.0, 0.0
|
[ 1.0, 0.0, 0.0
|
||||||
, 0.0, -2.0,-1.0] :: HStatic.L 2 3)
|
, 0.0, -2.0,-1.0] :: HStatic.L 2 3)
|
||||||
(nc, inX) = runBackards convLayer input grad
|
(nc, inX) = runBackwards convLayer input grad
|
||||||
|
|
||||||
in case (out, inX, nc) of
|
in case (out, inX, nc) of
|
||||||
(S3D' out' , S3D' inX', Convolution' backGrad)
|
(S3D' out' , S3D' inX', Convolution' backGrad)
|
||||||
|
Loading…
Reference in New Issue
Block a user