mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Reimplement pooling backwards for speed
This commit is contained in:
parent
bbd29e71bd
commit
b914b91c1f
@ -5,6 +5,7 @@ module Grenade.Layers.Internal.Pooling (
|
||||
, poolBackwardList
|
||||
) where
|
||||
|
||||
import Data.Foldable ( forM_ )
|
||||
import Data.Function ( on )
|
||||
import Data.List ( maximumBy )
|
||||
|
||||
@ -12,7 +13,6 @@ import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||
import qualified Numeric.LinearAlgebra as LA
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
|
||||
import Grenade.Layers.Internal.Convolution
|
||||
|
||||
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
@ -42,15 +42,30 @@ poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
|
||||
let starts = fittingStarts inRows krows srows inCols kcols scols
|
||||
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
|
||||
|
||||
-- poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
-- poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
|
||||
-- let inRows = rows inputMatrix
|
||||
-- inCols = cols inputMatrix
|
||||
-- inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
||||
-- grads = toList $ flatten gradientMatrix
|
||||
-- grads' = zip3 starts grads inds
|
||||
-- accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
|
||||
-- in accum (LA.konst 0 (inRows, inCols)) (+) accums
|
||||
|
||||
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $ do
|
||||
let inRows = rows inputMatrix
|
||||
inCols = cols inputMatrix
|
||||
inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
||||
grads = toList $ flatten gradientMatrix
|
||||
grads' = zip3 starts grads inds
|
||||
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
|
||||
in accum (LA.konst 0 (inRows, inCols)) (+) accums
|
||||
gradCol = cols gradientMatrix
|
||||
extent = (krows, kcols)
|
||||
|
||||
retM <- U.newMatrix 0 inRows inCols
|
||||
|
||||
forM_ (zip [0..] starts) $ \(ix, start) -> do
|
||||
let loc = unsafeMaxIndexSubMatrix start extent inputMatrix
|
||||
uncurry (unsafeModifyMatrix retM) loc ((+) $ uncurry (U.atM' gradientMatrix) $ divMod ix gradCol)
|
||||
|
||||
return retM
|
||||
|
||||
unsafeMaxElement :: Matrix Double -> Double
|
||||
unsafeMaxElement m = uncurry (U.atM' m) $ unsafeMaxIndex m
|
||||
@ -61,3 +76,12 @@ unsafeMaxIndex m =
|
||||
mcols = [0 .. cols m - 1]
|
||||
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
|
||||
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
|
||||
|
||||
|
||||
unsafeMaxIndexSubMatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> (Int, Int)
|
||||
unsafeMaxIndexSubMatrix (startRow, startCol) (extentRow, extentCold) m =
|
||||
let mrows = [startRow .. startRow + extentRow - 1]
|
||||
mcols = [startCol .. startCol + extentCold - 1]
|
||||
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
|
||||
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user