Make things faster

This commit is contained in:
Huw Campbell 2016-12-08 01:16:20 +11:00
parent d360438fc0
commit b090b5f073
16 changed files with 314 additions and 107 deletions

View File

@ -40,6 +40,7 @@ library
Grenade.Core.Shape Grenade.Core.Shape
Grenade.Layers.Crop Grenade.Layers.Crop
Grenade.Layers.Convolution Grenade.Layers.Convolution
Grenade.Layers.Convolution.Internal
Grenade.Layers.Dropout Grenade.Layers.Dropout
Grenade.Layers.FullyConnected Grenade.Layers.FullyConnected
Grenade.Layers.Flatten Grenade.Layers.Flatten

View File

@ -51,7 +51,7 @@ class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
-- layer gave from the input and the back propagated derivatives from -- layer gave from the input and the back propagated derivatives from
-- the layer above. -- the layer above.
-- Returns the gradient layer and the derivatives to push back further. -- Returns the gradient layer and the derivatives to push back further.
runBackards :: x -> S' i -> S' o -> (Gradient x, S' i) runBackwards :: x -> S' i -> S' o -> (Gradient x, S' i)
-- | Type of a network. -- | Type of a network.
-- The [*] type specifies the types of the layers. This is needed for parallel -- The [*] type specifies the types of the layers. This is needed for parallel

View File

@ -32,7 +32,7 @@ backPropagate network input target =
-- recursively run the rest of the network, and get the gradients from above. -- recursively run the rest of the network, and get the gradients from above.
(n', dWs') = go y n (n', dWs') = go y n
-- calculate the gradient for this layer to pass down, -- calculate the gradient for this layer to pass down,
(layer', dWs) = runBackards layer x dWs' (layer', dWs) = runBackwards layer x dWs'
in (layer' :/> n', dWs) in (layer' :/> n', dWs)
@ -40,7 +40,7 @@ backPropagate network input target =
go !x (O layer) go !x (O layer)
= let y = runForwards layer x = let y = runForwards layer x
-- the gradient (how much y affects the error) -- the gradient (how much y affects the error)
(layer', dWs) = runBackards layer x (y - target) (layer', dWs) = runBackwards layer x (y - target)
in (OG layer', dWs) in (OG layer', dWs)

View File

@ -16,26 +16,22 @@ module Grenade.Layers.Convolution (
Convolution (..) Convolution (..)
, Convolution' (..) , Convolution' (..)
, randomConvolution , randomConvolution
, im2col
, vid2col
, col2im
, col2vid
, fittingStarts
) where ) where
import Control.Monad.Random hiding (fromList) import Control.Monad.Random hiding ( fromList )
import Data.Maybe import Data.Maybe
import Data.Proxy import Data.Proxy
import Data.Singletons.TypeLits import Data.Singletons.TypeLits
import GHC.TypeLits import GHC.TypeLits
import Numeric.LinearAlgebra hiding (uniformSample, konst) import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra as LA import qualified Numeric.LinearAlgebra as LA
import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows) import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)
import Grenade.Core.Network import Grenade.Core.Network
import Grenade.Core.Shape import Grenade.Core.Shape
import Grenade.Core.Vector import Grenade.Core.Vector
import Grenade.Layers.Convolution.Internal
-- | A convolution layer for a neural network. -- | A convolution layer for a neural network.
-- This uses the im2col convolution trick popularised by Caffe, which essentially turns the -- This uses the im2col convolution trick popularised by Caffe, which essentially turns the
@ -153,18 +149,21 @@ instance ( KnownNat kernelRows
runForwards (Convolution kernel _) (S2D' input) = runForwards (Convolution kernel _) (S2D' input) =
let ex = extract input let ex = extract input
ek = extract kernel ek = extract kernel
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols) ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows) sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols) sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows) ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols) oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = im2col kx ky sx sy ex c = im2colUnsafe kx ky sx sy ix iy ex
mt = c LA.<> ek mt = c LA.<> ek
r = col2vid 1 1 1 1 ox oy mt r = col2vidUnsafe 1 1 1 1 ox oy mt
rs = fmap (fromJust . create) r rs = fmap (fromJust . create) r
in S3D' $ mkVector rs in S3D' $ mkVector rs
runBackards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
runBackwards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
let ex = extract input let ex = extract input
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows) ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols) iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
@ -174,17 +173,19 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols) sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows) ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols) oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = im2col kx ky sx sy ex fl = fromIntegral $ natVal (Proxy :: Proxy filters)
c = im2colUnsafe kx ky sx sy ix iy ex
eo = vecToList $ fmap extract dEdy eo = vecToList $ fmap extract dEdy
ek = extract kernel ek = extract kernel
vs = vid2col 1 1 1 1 ox oy eo vs = vid2colUnsafe fl 1 1 1 1 ox oy eo
kN = fromJust . create $ tr c LA.<> vs kN = fromJust . create $ tr c LA.<> vs
dW = vs LA.<> tr ek dW = vs LA.<> tr ek
xW = col2im kx ky sx sy ix iy dW xW = col2imUnsafe kx ky sx sy ix iy dW
in (Convolution' kN, S2D' . fromJust . create $ xW) in (Convolution' kN, S2D' . fromJust . create $ xW)
@ -215,12 +216,13 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols) sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows) ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols) oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = vid2col kx ky sx sy ix iy ex ch = fromIntegral $ natVal (Proxy :: Proxy channels)
c = vid2colUnsafe ch kx ky sx sy ix iy ex
mt = c LA.<> ek mt = c LA.<> ek
r = col2vid 1 1 1 1 ox oy mt r = col2vidUnsafe 1 1 1 1 ox oy mt
rs = fmap (fromJust . create) r rs = fmap (fromJust . create) r
in S3D' $ mkVector rs in S3D' $ mkVector rs
runBackards (Convolution kernel _) (S3D' input) (S3D' dEdy) = runBackwards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
let ex = vecToList $ fmap extract input let ex = vecToList $ fmap extract input
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows) ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols) iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
@ -230,77 +232,18 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols) sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows) ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols) oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = vid2col kx ky sx sy ix iy ex ch = fromIntegral $ natVal (Proxy :: Proxy channels)
fl = fromIntegral $ natVal (Proxy :: Proxy filters)
c = vid2colUnsafe ch kx ky sx sy ix iy ex
eo = vecToList $ fmap extract dEdy eo = vecToList $ fmap extract dEdy
ek = extract kernel ek = extract kernel
vs = vid2col 1 1 1 1 ox oy eo vs = vid2colUnsafe fl 1 1 1 1 ox oy eo
kN = fromJust . create $ tr c LA.<> vs kN = fromJust . create $ tr c LA.<> vs
dW = vs LA.<> tr ek dW = vs LA.<> tr ek
xW = col2vid kx ky sx sy ix iy dW xW = col2vidUnsafe kx ky sx sy ix iy dW
in (Convolution' kN, S3D' . mkVector . fmap (fromJust . create) $ xW) in (Convolution' kN, S3D' . mkVector . fmap (fromJust . create) $ xW)
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2col nrows ncols srows scols m =
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
in im2colFit starts nrows ncols m
im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
im2colFit starts nrows ncols m =
let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts
in fromRows imRows
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2col nrows ncols srows scols inputrows inputcols ms =
let starts = fittingStarts inputrows nrows srows inputcols ncols scols
subs = fmap (im2colFit starts nrows ncols) ms
in foldl1 (|||) subs
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vid nrows ncols srows scols drows dcols m =
let starts = fittingStart (cols m) (nrows * ncols) (nrows * ncols)
r = rows m
mats = fmap (\s -> subMatrix (0,s) (r, nrows * ncols) m) starts
colSts = fittingStarts drows nrows srows dcols ncols scols
in fmap (col2imfit colSts nrows ncols drows dcols) mats
col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2im krows kcols srows scols drows dcols m =
let starts = fittingStarts drows krows srows dcols kcols scols
in col2imfit starts krows kcols drows dcols m
col2imfit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imfit starts krows kcols drows dcols m =
let indicies = fmap (\[a,b] -> (a,b)) $ sequence [[0..(krows-1)], [0..(kcols-1)]]
convs = fmap (zip indicies . toList) . toRows $ m
pairs = zip convs starts
accums = concat $ fmap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
in accum (LA.konst 0 (drows, dcols)) (+) accums
-- | These functions are not even remotely safe, but it's only called from the statically typed
-- commands, so we should be good ?!?!?
-- Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
let rs = fittingStart nrows kernelrows steprows
cs = fittingStart ncols kernelcols stepcolsh
ls = sequence [rs, cs]
in fmap (\[a,b] -> (a,b)) ls
-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
let go left | left + kernel < width
= left : go (left + steps)
| left + kernel == width
= left : []
| otherwise
= error "Kernel and step do not fit in matrix."
in go 0

View File

@ -0,0 +1,238 @@
module Grenade.Layers.Convolution.Internal (
im2col
-- , im2colUnsafe
, vid2col
, col2im
, col2imFit
, col2vid
, col2vidUnsafe
, col2imUnsafe
, im2colUnsafe
, vid2colUnsafe
, fittingStarts
) where
import Control.Monad.ST
import Control.Parallel.Strategies ( parMap, rseq )
import Data.STRef
import Data.Foldable ( forM_ )
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra as LA
import qualified Numeric.LinearAlgebra.Devel as U
-- | Unroll an image into a matrix with one row per kernel-sized patch.
--
-- The arguments are, in order: kernel rows, kernel columns, row stride,
-- column stride and the image itself. The patch origins are derived from
-- the image's own dimensions via 'fittingStarts', so this raises an error
-- if the kernel and stride do not fit the image.
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2col nrows ncols srows scols m = im2colFit patchOrigins nrows ncols m
  where
    patchOrigins = fittingStarts (rows m) nrows srows (cols m) ncols scols
-- | Build the patch matrix for an image given precomputed patch origins.
--
-- For every @(row, column)@ origin in the list, the @nrows@ by @ncols@
-- sub-matrix starting there is flattened and becomes one row of the result,
-- in the same order as the origins.
im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
im2colFit starts nrows ncols m =
  fromRows [ flatten (subMatrix origin (nrows, ncols) m) | origin <- starts ]
-- | Unroll a stack of equally sized images (channels) into one patch matrix.
--
-- The arguments are, in order: kernel rows, kernel columns, row stride,
-- column stride, input rows, input columns and the list of channels.
-- Each channel is unrolled with 'im2colFit' (the channels are mapped with
-- @parMap rseq@) and the per-channel patch matrices are placed side by side,
-- first channel leftmost.
--
-- Raises an error when given no channels, or (through 'fittingStarts') when
-- the kernel and stride do not fit the stated input size.
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2col _ _ _ _ _ _ [] = error "vid2col: no input channels given."
vid2col nrows ncols srows scols inputrows inputcols ms =
  -- A single block row joins every channel in one pass, instead of
  -- re-copying the accumulated matrix once per channel.
  fromBlocks [parMap rseq (im2colFit starts nrows ncols) ms]
  where
    starts = fittingStarts inputrows nrows srows inputcols ncols scols
-- | Fold a patch matrix back into a list of images, one per group of
-- @krows * kcols@ columns.
--
-- The arguments are, in order: kernel rows, kernel columns, row stride,
-- column stride, destination rows, destination columns and the patch matrix.
-- The matrix is cut into consecutive column groups of one patch width each,
-- and every group is accumulated into its own destination image with
-- 'col2imFit' (the groups are mapped with @parMap rseq@).
--
-- Raises an error (through 'fittingStart' / 'fittingStarts') when the column
-- count is not a multiple of the patch width or the kernel and stride do not
-- fit the destination size.
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vid krows kcols srows scols drows dcols m =
  parMap rseq (col2imFit patchOrigins krows kcols drows dcols) columnGroups
  where
    patchWidth   = krows * kcols
    columnGroups = [ subMatrix (0, firstCol) (rows m, patchWidth) m
                   | firstCol <- fittingStart (cols m) patchWidth patchWidth ]
    patchOrigins = fittingStarts drows krows srows dcols kcols scols
-- | Fold a patch matrix back into a single image.
--
-- The arguments are, in order: kernel rows, kernel columns, row stride,
-- column stride, destination rows, destination columns and the patch matrix.
-- Patch origins are computed with 'fittingStarts' for the destination size
-- and the accumulation itself is done by 'col2imFit'.
col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2im krows kcols srows scols drows dcols m =
  col2imFit patchOrigins krows kcols drows dcols m
  where
    patchOrigins = fittingStarts drows krows srows dcols kcols scols
-- | Returns the starting sub matrix locations which fit inside the larger
-- matrix for the convolution. Takes into account the stride and kernel size.
--
-- The arguments are, in order: matrix rows, kernel rows, row stride, matrix
-- columns, kernel columns and column stride. The result lists every
-- @(row, column)@ origin, with the row varying slowest.
--
-- This is not safe on its own: 'fittingStart' raises an error when the
-- kernel and stride do not tile a dimension exactly. It is only meant to be
-- called from the statically typed layer code, which is expected to have
-- established the fit already.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcols =
  [ (r, c)
  | r <- fittingStart nrows kernelrows steprows
  , c <- fittingStart ncols kernelcols stepcols
  ]
-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
--
-- The arguments are, in order: the width of the vector, the kernel width and
-- the stride. Offsets start at 0 and advance by the stride until the kernel
-- ends exactly on the last element.
--
-- Raises an error when the kernel overshoots the end of the vector, and when
-- another step is required but the stride is not positive (which would
-- otherwise produce an endless list). The stride is not consulted when the
-- kernel already spans the whole vector.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps = go 0
  where
    go left
      | left + kernel == width = [left]
      | left + kernel > width  = error "Kernel and step do not fit in matrix."
      -- Only reached when a further step is needed; a stride that does not
      -- move forward could never reach the end of the vector.
      | steps <= 0             = error "fittingStart: stride must be positive when the kernel is narrower than the matrix."
      | otherwise              = left : go (left + steps)
-- | Accumulate the rows of a patch matrix into a destination image, given
-- precomputed patch origins.
--
-- Each row of @m@ is read as a row-major @krows@ by @kcols@ patch and paired
-- with the origin at the same position in @starts@. Every patch value is
-- added into a zero-filled @drows@ by @dcols@ matrix at its kernel position
-- shifted by that origin, so overlapping patches sum. Surplus rows or
-- origins are dropped by the pairing.
col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imFit starts krows kcols drows dcols m =
  accum (LA.konst 0 (drows, dcols)) (+) contributions
  where
    kernelCells   = [ (kr, kc) | kr <- [0 .. krows - 1], kc <- [0 .. kcols - 1] ]
    contributions = [ ((kr + originRow, kc + originCol), val)
                    | (patch, (originRow, originCol)) <- zip (toRows m) starts
                    , ((kr, kc), val) <- zip kernelCells (toList patch)
                    ]
-- void col2im_cpu(const Dtype* data_col, const int channels,
-- const int height, const int width, const int kernel_h, const int kernel_w,
-- const int pad_h, const int pad_w,
-- const int stride_h, const int stride_w,
-- const int dilation_h, const int dilation_w,
-- Dtype* data_im) {
-- caffe_set(height * width * channels, Dtype(0), data_im);
-- const int output_h = (height + 2 * pad_h -
-- (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
-- const int output_w = (width + 2 * pad_w -
-- (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
-- const int channel_size = height * width;
-- for (int channel = channels; channel--; data_im += channel_size) {
-- for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
-- for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
-- int input_row = -pad_h + kernel_row * dilation_h;
-- for (int output_rows = output_h; output_rows; output_rows--) {
-- if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
-- data_col += output_w;
-- } else {
-- int input_col = -pad_w + kernel_col * dilation_w;
-- for (int output_col = output_w; output_col; output_col--) {
-- if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
-- data_im[input_row * width + input_col] += *data_col;
-- }
-- data_col++;
-- input_col += stride_w;
-- }
-- }
-- input_row += stride_h;
-- }
-- }
-- }
-- }
-- }
-- let starts = fittingStart (cols m) (krows * kcols) (krows * kcols)
-- r = rows m
-- mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts
-- in parMap rseq (col2imUnsafe krows kcols srows scols drows dcols) mats
-- | Fold a patch matrix back into a single image using a mutable matrix.
--
-- The arguments are, in order: kernel rows, kernel columns, row stride,
-- column stride, destination rows, destination columns and the patch matrix.
-- Every row of the patch matrix is read as a row-major kernel-sized patch
-- and added into a zero-filled destination at the current patch origin.
--
-- The origin starts at @(0, 0)@. After each patch it moves right by the
-- column stride while the kernel still ends before the last destination
-- column; otherwise it returns to column 0 and moves down by the row stride.
-- Unlike 'col2im', nothing here checks that the kernel and stride actually
-- fit the destination.
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix =
  U.runSTMatrix $ do
    dataIm <- U.newMatrix 0 destinationRows destinationCols
    forM_ (zip [0 .. rows columnMatrix - 1] patchOrigins) $ \(patchIx, (originRow, originCol)) ->
      forM_ kernelCells $ \(kr, kc) ->
        U.modifyMatrix dataIm (kr + originRow) (kc + originCol)
          (+ atIndex columnMatrix (patchIx, kr * kernelColumns + kc))
    return dataIm
  where
    -- Kernel positions in row-major order, matching the column layout of a patch.
    kernelCells = [ (kr, kc) | kr <- [0 .. kernelRows - 1], kc <- [0 .. kernelColumns - 1] ]
    -- Origin of the n-th patch, generated lazily from the first one.
    patchOrigins = iterate nextOrigin (0, 0)
    nextOrigin (originRow, originCol)
      | originCol + kernelColumns < destinationCols = (originRow, originCol + strideColumns)
      | otherwise                                   = (originRow + strideRows, 0)
-- | Fold a patch matrix back into a list of images using mutable matrices.
--
-- The arguments are, in order: kernel rows, kernel columns, row stride,
-- column stride, destination rows, destination columns and the patch matrix.
-- The number of images is the column count of the patch matrix divided by
-- the patch width (@kernelRows * kernelColumns@); image @n@ is built from
-- the @n@-th group of that many columns.
--
-- For each image the patch origin starts at @(0, 0)@ and, after every row of
-- the patch matrix, moves right by the column stride while the kernel still
-- ends before the last destination column; otherwise it returns to column 0
-- and moves down by the row stride. Nothing here checks that the kernel and
-- stride actually fit the destination.
col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do
  dataIms <- traverse (const (U.newMatrix 0 destinationRows destinationCols)) [1 .. filters]
  forM_ (zip [0 ..] dataIms) $ \(filterIx, dataIm) ->
    forM_ (zip [0 .. rows columnMatrix - 1] patchOrigins) $ \(patchIx, (originRow, originCol)) ->
      forM_ kernelCells $ \(kr, kc) ->
        U.modifyMatrix dataIm (kr + originRow) (kc + originCol)
          (+ atIndex columnMatrix (patchIx, filterIx * patchWidth + kr * kernelColumns + kc))
  traverse U.freezeMatrix dataIms
  where
    patchWidth = kernelRows * kernelColumns
    filters    = cols columnMatrix `div` patchWidth
    -- Kernel positions in row-major order, matching the column layout of a patch.
    kernelCells = [ (kr, kc) | kr <- [0 .. kernelRows - 1], kc <- [0 .. kernelColumns - 1] ]
    -- Origin of the n-th patch; the same sequence is replayed for every image.
    patchOrigins = iterate nextOrigin (0, 0)
    nextOrigin (originRow, originCol)
      | originCol + kernelColumns < destinationCols = (originRow, originCol + strideColumns)
      | otherwise                                   = (originRow + strideRows, 0)
-- | Unroll a stack of images (channels) into one patch matrix using a
-- mutable matrix.
--
-- The arguments are, in order: channel count, kernel rows, kernel columns,
-- row stride, column stride, input rows, input columns and the channels.
-- The result has one row per patch origin (as computed by 'fittingStarts')
-- and @channels * kernelRows * kernelColumns@ columns; the patch taken from
-- channel @n@ fills the @n@-th group of @kernelRows * kernelColumns@
-- columns, in row-major kernel order.
--
-- The channel count and input size are trusted: they size the result but are
-- not compared against the matrices actually supplied. 'fittingStarts' still
-- raises an error when the kernel and stride do not fit the stated size.
vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2colUnsafe channels kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do
  dataCol <- U.newMatrix 0 patchCount (channels * patchWidth)
  forM_ (zip [0 ..] dataVid) $ \(channelIx, dataIm) ->
    forM_ (zip [0 ..] patchOrigins) $ \(patchIx, (startRow, startCol)) ->
      forM_ kernelCells $ \(kr, kc) ->
        U.modifyMatrix dataCol patchIx (channelIx * patchWidth + kr * kernelColumns + kc)
          (+ atIndex dataIm (kr + startRow, kc + startCol))
  return dataCol
  where
    patchOrigins = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
    patchWidth   = kernelRows * kernelColumns
    -- Number of patch positions along each axis, and in total.
    patchesDown   = 1 + (vidrows - kernelRows) `div` striderows
    patchesAcross = 1 + (vidcols - kernelColumns) `div` stridecols
    patchCount    = patchesDown * patchesAcross
    -- Kernel positions in row-major order, matching the column layout of a patch.
    kernelCells = [ (kr, kc) | kr <- [0 .. kernelRows - 1], kc <- [0 .. kernelColumns - 1] ]
-- | Unroll a single image into a patch matrix using a mutable matrix.
--
-- The arguments are, in order: kernel rows, kernel columns, row stride,
-- column stride, input rows, input columns and the image. The result has one
-- row per patch origin (as computed by 'fittingStarts') and
-- @kernelRows * kernelColumns@ columns holding that patch in row-major order.
--
-- The input size is trusted: it sizes the result but is not compared against
-- the matrix actually supplied. 'fittingStarts' still raises an error when
-- the kernel and stride do not fit the stated size.
im2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataIm = U.runSTMatrix $ do
  dataCol <- U.newMatrix 0 patchCount patchWidth
  forM_ (zip [0 ..] patchOrigins) $ \(patchIx, (startRow, startCol)) ->
    forM_ kernelCells $ \(kr, kc) ->
      U.modifyMatrix dataCol patchIx (kr * kernelColumns + kc)
        (+ atIndex dataIm (kr + startRow, kc + startCol))
  return dataCol
  where
    patchOrigins = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
    patchWidth   = kernelRows * kernelColumns
    -- Number of patch positions along each axis, and in total.
    patchesDown   = 1 + (vidrows - kernelRows) `div` striderows
    patchesAcross = 1 + (vidcols - kernelColumns) `div` stridecols
    patchCount    = patchesDown * patchesAcross
    -- Kernel positions in row-major order, matching the column layout of a patch.
    kernelCells = [ (kr, kc) | kr <- [0 .. kernelRows - 1], kc <- [0 .. kernelColumns - 1] ]

View File

@ -58,7 +58,7 @@ instance ( KnownNat cropLeft
m = extract input m = extract input
r = subMatrix (cropt, cropl) (nrows, ncols) m r = subMatrix (cropt, cropl) (nrows, ncols) m
in S2D' . fromJust . create $ r in S2D' . fromJust . create $ r
runBackards _ _ (S2D' dEdy) = runBackwards _ _ (S2D' dEdy) =
let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft) let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop) cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight) cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight)

View File

@ -47,5 +47,5 @@ randomDropout rate = do
instance (KnownNat i) => Layer (Dropout i) ('D1 i) ('D1 i) where instance (KnownNat i) => Layer (Dropout i) ('D1 i) ('D1 i) where
runForwards (Dropout drops) (S1D' x) = S1D' $ x * drops runForwards (Dropout drops) (S1D' x) = S1D' $ x * drops
runForwards (Pass rate) (S1D' x)= S1D' $ dvmap (* (1 - rate)) x runForwards (Pass rate) (S1D' x)= S1D' $ dvmap (* (1 - rate)) x
runBackards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops) runBackwards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops)
runBackards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x) runBackwards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x)

View File

@ -33,11 +33,11 @@ instance UpdateLayer FlattenLayer where
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
runBackards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y) runBackwards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
instance (KnownNat a, KnownNat x, KnownNat y, KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where instance (KnownNat a, KnownNat x, KnownNat y, KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
runForwards _ (S3D' y) = S1D' . raiseShapeError . create . vjoin . vecToList . fmap (flatten . extract) $ y runForwards _ (S3D' y) = S1D' . raiseShapeError . create . vjoin . vecToList . fmap (flatten . extract) $ y
runBackards _ _ (S1D' o) = runBackwards _ _ (S1D' o) =
let x' = fromIntegral $ natVal (Proxy :: Proxy x) let x' = fromIntegral $ natVal (Proxy :: Proxy x)
y' = fromIntegral $ natVal (Proxy :: Proxy y) y' = fromIntegral $ natVal (Proxy :: Proxy y)
z' = fromIntegral $ natVal (Proxy :: Proxy z) z' = fromIntegral $ natVal (Proxy :: Proxy z)

View File

@ -52,7 +52,7 @@ instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o)
runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v) runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)
-- Run a backpropagation step for a fully connected layer. -- Run a backpropagation step for a fully connected layer.
runBackards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) = runBackwards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
let wB' = dEdy let wB' = dEdy
mm' = dEdy `outer` x mm' = dEdy `outer` x
-- calculate derivatives for next step -- calculate derivatives for next step

View File

@ -45,8 +45,8 @@ instance (Layer x i h, Layer y h o) => Layer (Fuse x y i h o) i o where
let yInput :: S' h = runForwards x input let yInput :: S' h = runForwards x input
in runForwards y yInput in runForwards y yInput
runBackards (x :$$ y) input backGradient = runBackwards (x :$$ y) input backGradient =
let yInput :: S' h = runForwards x input let yInput :: S' h = runForwards x input
(y', yGrad) = runBackards y yInput backGradient (y', yGrad) = runBackwards y yInput backGradient
(x', xGrad) = runBackards x input yGrad (x', xGrad) = runBackwards x input yGrad
in ((x', y'), xGrad) in ((x', y'), xGrad)

View File

@ -29,15 +29,15 @@ instance UpdateLayer Logit where
instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (logistic y) runForwards _ (S1D' y) = S1D' (logistic y)
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy)) runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy))
instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where
runForwards _ (S2D' y) = S2D' (logistic y) runForwards _ (S2D' y) = S2D' (logistic y)
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy)) runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where
runForwards _ (S3D' y) = S3D' (fmap logistic y) runForwards _ (S3D' y) = S3D' (fmap logistic y)
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> logistic' y' * dEdy') y dEdy)) runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> logistic' y' * dEdy') y dEdy))
logistic :: Floating a => a -> a logistic :: Floating a => a -> a

View File

@ -58,7 +58,7 @@ instance ( KnownNat padLeft
m = extract input m = extract input
r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)] r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
in S2D' . fromJust . create $ r in S2D' . fromJust . create $ r
runBackards Pad _ (S2D' dEdy) = runBackwards Pad _ (S2D' dEdy) =
let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
padt = fromIntegral $ natVal (Proxy :: Proxy padTop) padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows) nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)

View File

@ -24,7 +24,7 @@ import GHC.TypeLits
import Grenade.Core.Network import Grenade.Core.Network
import Grenade.Core.Shape import Grenade.Core.Shape
import Grenade.Core.Vector import Grenade.Core.Vector
import Grenade.Layers.Convolution import Grenade.Layers.Convolution.Internal
import Numeric.LinearAlgebra hiding (uniformSample) import Numeric.LinearAlgebra hiding (uniformSample)
import qualified Numeric.LinearAlgebra as LA import qualified Numeric.LinearAlgebra as LA
@ -75,7 +75,7 @@ instance ( KnownNat kernelRows
r = poolForward kx ky sx sy ox oy $ ex r = poolForward kx ky sx sy ox oy $ ex
rs = fromJust . create $ r rs = fromJust . create $ r
in S2D' $ rs in S2D' $ rs
runBackards Pooling (S2D' input) (S2D' dEdy) = runBackwards Pooling (S2D' input) (S2D' dEdy) =
let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns) ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows) sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
@ -111,7 +111,7 @@ instance ( KnownNat kernelRows
r = poolForwardList kx ky sx sy ix iy ox oy ex r = poolForwardList kx ky sx sy ix iy ox oy ex
rs = fmap (fromJust . create) r rs = fmap (fromJust . create) r
in S3D' rs in S3D' rs
runBackards Pooling (S3D' input) (S3D' dEdy) = runBackwards Pooling (S3D' input) (S3D' dEdy) =
let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows) let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns) iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)

View File

@ -31,7 +31,7 @@ instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (relu y) runForwards _ (S1D' y) = S1D' (relu y)
where where
relu = LAS.dvmap (\a -> if a <= 0 then 0 else a) relu = LAS.dvmap (\a -> if a <= 0 then 0 else a)
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy)) runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy))
where where
relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1) relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1)
@ -39,7 +39,7 @@ instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where
runForwards _ (S2D' y) = S2D' (relu y) runForwards _ (S2D' y) = S2D' (relu y)
where where
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a) relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy)) runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy))
where where
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1) relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
@ -47,6 +47,6 @@ instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j
runForwards _ (S3D' y) = S3D' (fmap relu y) runForwards _ (S3D' y) = S3D' (fmap relu y)
where where
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a) relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> relu' y' * dEdy') y dEdy)) runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> relu' y' * dEdy') y dEdy))
where where
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1) relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)

View File

@ -26,15 +26,15 @@ instance UpdateLayer Tanh where
instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (tanh y) runForwards _ (S1D' y) = S1D' (tanh y)
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy)) runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy))
instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where
runForwards _ (S2D' y) = S2D' (tanh y) runForwards _ (S2D' y) = S2D' (tanh y)
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy)) runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where
runForwards _ (S3D' y) = S3D' (fmap tanh y) runForwards _ (S3D' y) = S3D' (fmap tanh y)
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> tanh' y' * dEdy') y dEdy)) runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> tanh' y' * dEdy') y dEdy))
tanh' :: (Floating a) => a -> a tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t tanh' t = 1 - s ^ (2 :: Int) where s = tanh t

View File

@ -8,6 +8,7 @@ import Grenade.Core.Shape
import Grenade.Core.Vector as Grenade import Grenade.Core.Vector as Grenade
import Grenade.Core.Network import Grenade.Core.Network
import Grenade.Layers.Convolution import Grenade.Layers.Convolution
import Grenade.Layers.Convolution.Internal
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===)) import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
import qualified Numeric.LinearAlgebra.Static as HStatic import qualified Numeric.LinearAlgebra.Static as HStatic
@ -63,6 +64,17 @@ prop_im2col_sym_on_same_stride = once $
out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input
in input === out in input === out
-- If there's no overlap (stride is the same size as the kernel)
-- then col2im . im2col should be symmetric.
prop_im2colunsafe_sym_on_same_stride = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 3 4 $ input
in input === out
-- If there is an overlap, then the gradient passed back should be -- If there is an overlap, then the gradient passed back should be
-- the sum of the gradients across the filters. -- the sum of the gradients across the filters.
prop_im2col_col2im_additive = once $ prop_im2col_col2im_additive = once $
@ -127,7 +139,7 @@ prop_simple_conv_forwards = once $
expectBack = (HStatic.matrix expectBack = (HStatic.matrix
[ 1.0, 0.0, 0.0 [ 1.0, 0.0, 0.0
, 0.0, -2.0,-1.0] :: HStatic.L 2 3) , 0.0, -2.0,-1.0] :: HStatic.L 2 3)
(nc, inX) = runBackards convLayer input grad (nc, inX) = runBackwards convLayer input grad
in case (out, inX, nc) of in case (out, inX, nc) of
(S3D' out' , S2D' inX', Convolution' backGrad) (S3D' out' , S2D' inX', Convolution' backGrad)
@ -187,6 +199,19 @@ prop_vid2col_invert = once $
out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
in input === out in input === out
prop_vid2col_invert_unsafe = once $
let input = [(3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
, (3><4)
[ 21.0, 22.0, 23.0, 24.0
, 25.0, 26.0, 27.0, 28.0
, 29.0, 30.0, 31.0, 32.0 ] ]
out = col2vidUnsafe 3 2 3 2 3 4 . vid2colUnsafe 2 3 2 3 2 3 4 $ input
in input === out
-- This test shows that 2D convs act the same as -- This test shows that 2D convs act the same as
-- 3D convs with one layer -- 3D convs with one layer
prop_single_conv_forwards = once $ prop_single_conv_forwards = once $
@ -239,7 +264,7 @@ prop_single_conv_forwards = once $
expectBack = (HStatic.matrix expectBack = (HStatic.matrix
[ 1.0, 0.0, 0.0 [ 1.0, 0.0, 0.0
, 0.0, -2.0,-1.0] :: HStatic.L 2 3) , 0.0, -2.0,-1.0] :: HStatic.L 2 3)
(nc, inX) = runBackards convLayer input grad (nc, inX) = runBackwards convLayer input grad
in case (out, inX, nc) of in case (out, inX, nc) of
(S3D' out' , S3D' inX', Convolution' backGrad) (S3D' out' , S3D' inX', Convolution' backGrad)