Make ST modifications strict

This commit is contained in:
Huw Campbell 2016-12-09 20:34:50 +11:00
parent 8b3ca1e0b6
commit bbd29e71bd

View File

@ -4,14 +4,17 @@ module Grenade.Layers.Internal.Convolution (
, vid2colUnsafe
, im2colUnsafe
, fittingStarts
, unsafeModifyMatrix
) where
import Control.Monad.ST ( runST )
import Control.Monad.ST ( ST, runST )
import Data.STRef ( newSTRef, modifySTRef, writeSTRef, readSTRef )
import Data.STRef ( newSTRef, modifySTRef', writeSTRef, readSTRef )
import Data.Foldable ( forM_ )
import Data.Traversable ( forM )
import Foreign.Storable( Storable )
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra.Devel as U
@ -72,18 +75,17 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d
forM_ [0 .. columnMatrixRows - 1] $ \inputRow -> do
inputColumnRef <- newSTRef 0
forM_ [0 .. kernelRows -1] $ \kr ->
forM_ [0 .. kernelColumns -1] $ \kc -> do
inputColumn <- readSTRef inputColumnRef
offsetR' <- readSTRef offsetR
offsetC' <- readSTRef offsetC
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix inputRow inputColumn)
modifySTRef inputColumnRef (+1)
offsetR' <- readSTRef offsetR
offsetC' <- readSTRef offsetC
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
inputColumn <- readSTRef inputColumnRef
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix inputRow inputColumn)
modifySTRef' inputColumnRef (+1)
if offsetC' + kernelColumns < destinationCols
then modifySTRef offsetC (+ strideColumns)
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
then modifySTRef' offsetC (+ strideColumns)
else writeSTRef offsetC 0 >> modifySTRef' offsetR (+ strideRows)
return dataIm
@ -104,18 +106,17 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
offsetC <- newSTRef 0
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
inputColumn <- newSTRef 0
forM_ [0 .. kernelRows -1] $ \kr ->
forM_ [0 .. kernelColumns -1] $ \kc -> do
ic <- readSTRef inputColumn
offsetR' <- readSTRef offsetR
offsetC' <- readSTRef offsetC
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix ir (ic + offsetM))
modifySTRef inputColumn (+1)
offsetR' <- readSTRef offsetR
offsetC' <- readSTRef offsetC
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
ic <- readSTRef inputColumn
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix ir (ic + offsetM))
modifySTRef' inputColumn (+1)
if offsetC' + kernelColumns < destinationCols
then modifySTRef offsetC (+ strideColumns)
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
then modifySTRef' offsetC (+ strideColumns)
else writeSTRef offsetC 0 >> modifySTRef' offsetR (+ strideRows)
U.unsafeFreezeMatrix dataIm
@ -136,14 +137,14 @@ vid2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dat
forM_ starts $ \(startRow, startCol) -> do
inputColumnRef <- newSTRef 0
inputRow <- readSTRef inputRowRef
forM_ [0 .. kernelRows -1] $ \kr ->
forM_ [0 .. kernelColumns -1] $ \kc -> do
forM_ [startRow .. startRow + kernelRows -1] $ \kr ->
forM_ [startCol .. startCol + kernelColumns -1] $ \kc -> do
inputColumn <- readSTRef inputColumnRef
U.modifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm (kr + startRow) (kc + startCol))
modifySTRef inputColumnRef (+1)
modifySTRef inputRowRef (+1)
unsafeModifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm kr kc)
modifySTRef' inputColumnRef (+1)
modifySTRef' inputRowRef (+1)
modifySTRef offsetC (+ kernelSize)
modifySTRef' offsetC (+ kernelSize)
return dataCol
@ -159,15 +160,19 @@ im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatr
forM_ starts $ \(startRow, startCol) -> do
inputColumnRef <- newSTRef 0
inputRow <- readSTRef inputRowRef
forM_ [0 .. kernelRows -1] $ \kr ->
forM_ [0 .. kernelColumns -1] $ \kc -> do
forM_ [startRow .. startRow + kernelRows -1] $ \kr ->
forM_ [startCol .. startCol + kernelColumns -1] $ \kc -> do
inputColumn <- readSTRef inputColumnRef
U.modifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm (kr + startRow) (kc + startCol))
modifySTRef inputColumnRef (+1)
modifySTRef inputRowRef (+1)
unsafeModifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm kr kc)
modifySTRef' inputColumnRef (+1)
modifySTRef' inputRowRef (+1)
return dataCol
unsafeModifyMatrix :: (Storable t) => U.STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
unsafeModifyMatrix x r c f = U.unsafeReadMatrix x r c >>= U.unsafeWriteMatrix x r c . f
{-# INLINE unsafeModifyMatrix #-}
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.