Reimplement pooling backwards for speed

This commit is contained in:
Huw Campbell 2016-12-09 20:35:04 +11:00
parent bbd29e71bd
commit b914b91c1f

View File

@ -5,6 +5,7 @@ module Grenade.Layers.Internal.Pooling (
, poolBackwardList
) where
import Data.Foldable ( forM_ )
import Data.Function ( on )
import Data.List ( maximumBy )
@ -12,7 +13,6 @@ import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra as LA
import qualified Numeric.LinearAlgebra.Devel as U
import Grenade.Layers.Internal.Convolution
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
@ -42,15 +42,30 @@ poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
let starts = fittingStarts inRows krows srows inCols kcols scols
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
-- poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
-- poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
-- let inRows = rows inputMatrix
-- inCols = cols inputMatrix
-- inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
-- grads = toList $ flatten gradientMatrix
-- grads' = zip3 starts grads inds
-- accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
-- in accum (LA.konst 0 (inRows, inCols)) (+) accums
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $ do
let inRows = rows inputMatrix
inCols = cols inputMatrix
inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
grads = toList $ flatten gradientMatrix
grads' = zip3 starts grads inds
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
in accum (LA.konst 0 (inRows, inCols)) (+) accums
gradCol = cols gradientMatrix
extent = (krows, kcols)
retM <- U.newMatrix 0 inRows inCols
forM_ (zip [0..] starts) $ \(ix, start) -> do
let loc = unsafeMaxIndexSubMatrix start extent inputMatrix
uncurry (unsafeModifyMatrix retM) loc ((+) $ uncurry (U.atM' gradientMatrix) $ divMod ix gradCol)
return retM
unsafeMaxElement :: Matrix Double -> Double
unsafeMaxElement m = uncurry (U.atM' m) $ unsafeMaxIndex m
@ -61,3 +76,12 @@ unsafeMaxIndex m =
mcols = [0 .. cols m - 1]
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
unsafeMaxIndexSubMatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> (Int, Int)
unsafeMaxIndexSubMatrix (startRow, startCol) (extentRow, extentCold) m =
let mrows = [startRow .. startRow + extentRow - 1]
mcols = [startCol .. startCol + extentCold - 1]
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
in maximumBy (compare `on` uncurry (U.atM' m)) pairs