mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Merge pull request #14 from HuwCampbell/topic/performance
Topic/performance
This commit is contained in:
commit
2cffa4a568
@ -4,14 +4,17 @@ module Grenade.Layers.Internal.Convolution (
|
||||
, vid2colUnsafe
|
||||
, im2colUnsafe
|
||||
, fittingStarts
|
||||
, unsafeModifyMatrix
|
||||
) where
|
||||
|
||||
import Control.Monad.ST ( runST )
|
||||
import Control.Monad.ST ( ST, runST )
|
||||
|
||||
import Data.STRef ( newSTRef, modifySTRef, writeSTRef, readSTRef )
|
||||
import Data.STRef ( newSTRef, modifySTRef', writeSTRef, readSTRef )
|
||||
import Data.Foldable ( forM_ )
|
||||
import Data.Traversable ( forM )
|
||||
|
||||
import Foreign.Storable( Storable )
|
||||
|
||||
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
@ -72,18 +75,17 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d
|
||||
|
||||
forM_ [0 .. columnMatrixRows - 1] $ \inputRow -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix inputRow inputColumn)
|
||||
modifySTRef inputColumnRef (+1)
|
||||
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
|
||||
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix inputRow inputColumn)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
|
||||
if offsetC' + kernelColumns < destinationCols
|
||||
then modifySTRef offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
|
||||
then modifySTRef' offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef' offsetR (+ strideRows)
|
||||
|
||||
return dataIm
|
||||
|
||||
@ -104,18 +106,17 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
|
||||
offsetC <- newSTRef 0
|
||||
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
|
||||
inputColumn <- newSTRef 0
|
||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||
ic <- readSTRef inputColumn
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix ir (ic + offsetM))
|
||||
modifySTRef inputColumn (+1)
|
||||
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
|
||||
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
|
||||
ic <- readSTRef inputColumn
|
||||
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix ir (ic + offsetM))
|
||||
modifySTRef' inputColumn (+1)
|
||||
|
||||
if offsetC' + kernelColumns < destinationCols
|
||||
then modifySTRef offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
|
||||
then modifySTRef' offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef' offsetR (+ strideRows)
|
||||
|
||||
U.unsafeFreezeMatrix dataIm
|
||||
|
||||
@ -136,14 +137,14 @@ vid2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dat
|
||||
forM_ starts $ \(startRow, startCol) -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
inputRow <- readSTRef inputRowRef
|
||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||
forM_ [startRow .. startRow + kernelRows -1] $ \kr ->
|
||||
forM_ [startCol .. startCol + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
U.modifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm (kr + startRow) (kc + startCol))
|
||||
modifySTRef inputColumnRef (+1)
|
||||
modifySTRef inputRowRef (+1)
|
||||
unsafeModifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm kr kc)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
modifySTRef' inputRowRef (+1)
|
||||
|
||||
modifySTRef offsetC (+ kernelSize)
|
||||
modifySTRef' offsetC (+ kernelSize)
|
||||
|
||||
return dataCol
|
||||
|
||||
@ -159,15 +160,19 @@ im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatr
|
||||
forM_ starts $ \(startRow, startCol) -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
inputRow <- readSTRef inputRowRef
|
||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||
forM_ [startRow .. startRow + kernelRows -1] $ \kr ->
|
||||
forM_ [startCol .. startCol + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
U.modifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm (kr + startRow) (kc + startCol))
|
||||
modifySTRef inputColumnRef (+1)
|
||||
modifySTRef inputRowRef (+1)
|
||||
unsafeModifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm kr kc)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
modifySTRef' inputRowRef (+1)
|
||||
|
||||
return dataCol
|
||||
|
||||
unsafeModifyMatrix :: (Storable t) => U.STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
|
||||
unsafeModifyMatrix x r c f = U.unsafeReadMatrix x r c >>= U.unsafeWriteMatrix x r c . f
|
||||
{-# INLINE unsafeModifyMatrix #-}
|
||||
|
||||
|
||||
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
|
||||
-- convolution. Takes into account the stride and kernel size.
|
||||
|
@ -5,6 +5,7 @@ module Grenade.Layers.Internal.Pooling (
|
||||
, poolBackwardList
|
||||
) where
|
||||
|
||||
import Data.Foldable ( forM_ )
|
||||
import Data.Function ( on )
|
||||
import Data.List ( maximumBy )
|
||||
|
||||
@ -12,7 +13,6 @@ import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||
import qualified Numeric.LinearAlgebra as LA
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
|
||||
import Grenade.Layers.Internal.Convolution
|
||||
|
||||
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
@ -27,7 +27,7 @@ poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
|
||||
|
||||
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
poolForwardFit starts nrows ncols _ outputCols m =
|
||||
let els = fmap (\start -> unsafeMaxElement $ subMatrix start (nrows, ncols) m) starts
|
||||
let els = fmap (\start -> unsafeMaxElementSubmatrix start (nrows, ncols) m) starts
|
||||
in LA.matrix outputCols els
|
||||
|
||||
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
@ -43,21 +43,27 @@ poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
|
||||
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
|
||||
|
||||
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $ do
|
||||
let inRows = rows inputMatrix
|
||||
inCols = cols inputMatrix
|
||||
inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
||||
grads = toList $ flatten gradientMatrix
|
||||
grads' = zip3 starts grads inds
|
||||
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
|
||||
in accum (LA.konst 0 (inRows, inCols)) (+) accums
|
||||
gradCol = cols gradientMatrix
|
||||
extent = (krows, kcols)
|
||||
|
||||
unsafeMaxElement :: Matrix Double -> Double
|
||||
unsafeMaxElement m = uncurry (U.atM' m) $ unsafeMaxIndex m
|
||||
retM <- U.newMatrix 0 inRows inCols
|
||||
|
||||
unsafeMaxIndex :: Matrix Double -> (Int, Int)
|
||||
unsafeMaxIndex m =
|
||||
let mrows = [0 .. rows m - 1]
|
||||
mcols = [0 .. cols m - 1]
|
||||
forM_ (zip [0..] starts) $ \(ix, start) -> do
|
||||
let loc = unsafeMaxIndexSubMatrix start extent inputMatrix
|
||||
uncurry (unsafeModifyMatrix retM) loc ((+) $ uncurry (U.atM' gradientMatrix) $ divMod ix gradCol)
|
||||
|
||||
return retM
|
||||
|
||||
unsafeMaxElementSubmatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> Double
|
||||
unsafeMaxElementSubmatrix starts extent m = uncurry (U.atM' m) $ unsafeMaxIndexSubMatrix starts extent m
|
||||
|
||||
unsafeMaxIndexSubMatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> (Int, Int)
|
||||
unsafeMaxIndexSubMatrix (startRow, startCol) (extentRow, extentCold) m =
|
||||
let mrows = [startRow .. startRow + extentRow - 1]
|
||||
mcols = [startCol .. startCol + extentCold - 1]
|
||||
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
|
||||
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user