mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Make ST modifications strict
This commit is contained in:
parent
8b3ca1e0b6
commit
bbd29e71bd
@ -4,14 +4,17 @@ module Grenade.Layers.Internal.Convolution (
|
||||
, vid2colUnsafe
|
||||
, im2colUnsafe
|
||||
, fittingStarts
|
||||
, unsafeModifyMatrix
|
||||
) where
|
||||
|
||||
import Control.Monad.ST ( runST )
|
||||
import Control.Monad.ST ( ST, runST )
|
||||
|
||||
import Data.STRef ( newSTRef, modifySTRef, writeSTRef, readSTRef )
|
||||
import Data.STRef ( newSTRef, modifySTRef', writeSTRef, readSTRef )
|
||||
import Data.Foldable ( forM_ )
|
||||
import Data.Traversable ( forM )
|
||||
|
||||
import Foreign.Storable( Storable )
|
||||
|
||||
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
@ -72,18 +75,17 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d
|
||||
|
||||
forM_ [0 .. columnMatrixRows - 1] $ \inputRow -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix inputRow inputColumn)
|
||||
modifySTRef inputColumnRef (+1)
|
||||
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
|
||||
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix inputRow inputColumn)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
|
||||
if offsetC' + kernelColumns < destinationCols
|
||||
then modifySTRef offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
|
||||
then modifySTRef' offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef' offsetR (+ strideRows)
|
||||
|
||||
return dataIm
|
||||
|
||||
@ -104,18 +106,17 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
|
||||
offsetC <- newSTRef 0
|
||||
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
|
||||
inputColumn <- newSTRef 0
|
||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||
ic <- readSTRef inputColumn
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix ir (ic + offsetM))
|
||||
modifySTRef inputColumn (+1)
|
||||
|
||||
offsetR' <- readSTRef offsetR
|
||||
offsetC' <- readSTRef offsetC
|
||||
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
|
||||
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
|
||||
ic <- readSTRef inputColumn
|
||||
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix ir (ic + offsetM))
|
||||
modifySTRef' inputColumn (+1)
|
||||
|
||||
if offsetC' + kernelColumns < destinationCols
|
||||
then modifySTRef offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
|
||||
then modifySTRef' offsetC (+ strideColumns)
|
||||
else writeSTRef offsetC 0 >> modifySTRef' offsetR (+ strideRows)
|
||||
|
||||
U.unsafeFreezeMatrix dataIm
|
||||
|
||||
@ -136,14 +137,14 @@ vid2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dat
|
||||
forM_ starts $ \(startRow, startCol) -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
inputRow <- readSTRef inputRowRef
|
||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||
forM_ [startRow .. startRow + kernelRows -1] $ \kr ->
|
||||
forM_ [startCol .. startCol + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
U.modifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm (kr + startRow) (kc + startCol))
|
||||
modifySTRef inputColumnRef (+1)
|
||||
modifySTRef inputRowRef (+1)
|
||||
unsafeModifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm kr kc)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
modifySTRef' inputRowRef (+1)
|
||||
|
||||
modifySTRef offsetC (+ kernelSize)
|
||||
modifySTRef' offsetC (+ kernelSize)
|
||||
|
||||
return dataCol
|
||||
|
||||
@ -159,15 +160,19 @@ im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatr
|
||||
forM_ starts $ \(startRow, startCol) -> do
|
||||
inputColumnRef <- newSTRef 0
|
||||
inputRow <- readSTRef inputRowRef
|
||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||
forM_ [startRow .. startRow + kernelRows -1] $ \kr ->
|
||||
forM_ [startCol .. startCol + kernelColumns -1] $ \kc -> do
|
||||
inputColumn <- readSTRef inputColumnRef
|
||||
U.modifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm (kr + startRow) (kc + startCol))
|
||||
modifySTRef inputColumnRef (+1)
|
||||
modifySTRef inputRowRef (+1)
|
||||
unsafeModifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm kr kc)
|
||||
modifySTRef' inputColumnRef (+1)
|
||||
modifySTRef' inputRowRef (+1)
|
||||
|
||||
return dataCol
|
||||
|
||||
unsafeModifyMatrix :: (Storable t) => U.STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
|
||||
unsafeModifyMatrix x r c f = U.unsafeReadMatrix x r c >>= U.unsafeWriteMatrix x r c . f
|
||||
{-# INLINE unsafeModifyMatrix #-}
|
||||
|
||||
|
||||
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
|
||||
-- convolution. Takes into account the stride and kernel size.
|
||||
|
Loading…
Reference in New Issue
Block a user