mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Add faster maxIndex and maxElement for matrices
This commit is contained in:
parent
ff20855676
commit
24289dba41
@ -5,8 +5,13 @@ module Grenade.Layers.Internal.Pooling (
|
||||
, poolBackwardList
|
||||
) where
|
||||
|
||||
import Data.Function ( on )
|
||||
import Data.List ( maximumBy )
|
||||
|
||||
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||
import qualified Numeric.LinearAlgebra as LA
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
|
||||
import Grenade.Layers.Internal.Convolution
|
||||
|
||||
@ -22,7 +27,7 @@ poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
|
||||
|
||||
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
poolForwardFit starts nrows ncols _ outputCols m =
|
||||
let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
|
||||
let els = fmap (\start -> unsafeMaxElement $ subMatrix start (nrows, ncols) m) starts
|
||||
in LA.matrix outputCols els
|
||||
|
||||
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
@ -41,8 +46,18 @@ poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
|
||||
let inRows = rows inputMatrix
|
||||
inCols = cols inputMatrix
|
||||
inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
||||
inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
||||
grads = toList $ flatten gradientMatrix
|
||||
grads' = zip3 starts grads inds
|
||||
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
|
||||
in accum (LA.konst 0 (inRows, inCols)) (+) accums
|
||||
|
||||
unsafeMaxElement :: Matrix Double -> Double
|
||||
unsafeMaxElement m = uncurry (U.atM' m) $ unsafeMaxIndex m
|
||||
|
||||
unsafeMaxIndex :: Matrix Double -> (Int, Int)
|
||||
unsafeMaxIndex m =
|
||||
let mrows = [0 .. rows m - 1]
|
||||
mcols = [0 .. cols m - 1]
|
||||
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
|
||||
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
|
||||
|
@ -12,8 +12,6 @@
|
||||
|
||||
module Grenade.Layers.Pooling (
|
||||
Pooling (..)
|
||||
, poolForward
|
||||
, poolBackward
|
||||
) where
|
||||
|
||||
import Data.Maybe
|
||||
|
@ -93,22 +93,22 @@ prop_simple_conv_forwards = once $
|
||||
-- Create a convolution kernel with 4 filters.
|
||||
-- [ 1, 0 [ 0, 1 [ 0, 1 [ 0, 0
|
||||
-- , 0,-1 ] ,-1, 0 ] , 1, 0 ] ,-1,-1 ]
|
||||
let myKernel = (HStatic.matrix
|
||||
let myKernel = HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 1.0, 1.0, 0.0
|
||||
, 0.0, -1.0, 1.0, -1.0
|
||||
,-1.0, 0.0, 0.0, -1.0 ] :: HStatic.L 4 4)
|
||||
zeroKernel = (HStatic.matrix
|
||||
,-1.0, 0.0, 0.0, -1.0 ] :: HStatic.L 4 4
|
||||
zeroKernel = HStatic.matrix
|
||||
[ 0.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 0.0, 0.0, 0.0 ] :: HStatic.L 4 4)
|
||||
, 0.0, 0.0, 0.0, 0.0 ] :: HStatic.L 4 4
|
||||
|
||||
expectedGradient = (HStatic.matrix
|
||||
expectedGradient = HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0, 2.0
|
||||
, 2.0, 0.0, 0.0, 5.0
|
||||
, 3.0, 0.0, 0.0, 4.0
|
||||
, 4.0, 0.0, 0.0, 6.0 ] :: HStatic.L 4 4)
|
||||
, 4.0, 0.0, 0.0, 6.0 ] :: HStatic.L 4 4
|
||||
|
||||
convLayer = Convolution myKernel zeroKernel :: Convolution 1 4 2 2 1 1
|
||||
|
||||
@ -116,25 +116,25 @@ prop_simple_conv_forwards = once $
|
||||
[ 1.0, 2.0, 5.0
|
||||
, 3.0, 4.0, 6.0] :: HStatic.L 2 3)
|
||||
|
||||
expect = ([(HStatic.matrix
|
||||
[ -3.0 , -4.0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ -1.0 , 1.0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ 5.0 , 9.0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ -7.0 , -10.0 ] :: HStatic.L 1 2)]) :: [HStatic.L 1 2]
|
||||
expect = [ HStatic.matrix
|
||||
[ -3.0 , -4.0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ -1.0 , 1.0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ 5.0 , 9.0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ -7.0 , -10.0 ] :: HStatic.L 1 2] :: [HStatic.L 1 2]
|
||||
out = runForwards convLayer input :: S' ('D3 1 2 4)
|
||||
|
||||
grad = S3D' ( mkVector
|
||||
[(HStatic.matrix
|
||||
[ 1 , 0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ 0 , 1 ] :: HStatic.L 1 2)] ) :: S' ('D3 1 2 4)
|
||||
[ HStatic.matrix
|
||||
[ 1 , 0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2
|
||||
, HStatic.matrix
|
||||
[ 0 , 1 ] :: HStatic.L 1 2] ) :: S' ('D3 1 2 4)
|
||||
|
||||
expectBack = (HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0
|
||||
@ -144,11 +144,8 @@ prop_simple_conv_forwards = once $
|
||||
in case (out, inX, nc) of
|
||||
(S3D' out' , S2D' inX', Convolution' backGrad)
|
||||
-> ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
|
||||
.&&. ((HStatic.extract expectBack) === (HStatic.extract inX'))
|
||||
.&&. ((HStatic.extract expectedGradient) === (HStatic.extract backGrad))
|
||||
-- Temporarily disabled, as l2 adjustment puts in off 5%
|
||||
-- .&&. HStatic.extract expectedKernel === HStatic.extract kernel'
|
||||
|
||||
.&&. (HStatic.extract expectBack === HStatic.extract inX')
|
||||
.&&. (HStatic.extract expectedGradient === HStatic.extract backGrad)
|
||||
|
||||
prop_vid2col_no_stride = once $
|
||||
let input = [(3><4)
|
||||
@ -228,25 +225,25 @@ prop_single_conv_forwards = once $
|
||||
[ 1.0, 2.0, 5.0
|
||||
, 3.0, 4.0, 6.0] :: HStatic.L 2 3] ) :: S' ('D3 2 3 1)
|
||||
|
||||
expect = ([(HStatic.matrix
|
||||
[ -3.0 , -4.0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ -1.0 , 1.0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ 5.0 , 9.0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ -7.0 , -10.0 ] :: HStatic.L 1 2)]) :: [HStatic.L 1 2]
|
||||
expect = [HStatic.matrix
|
||||
[ -3.0 , -4.0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ -1.0 , 1.0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ 5.0 , 9.0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ -7.0 , -10.0 ] :: HStatic.L 1 2] :: [HStatic.L 1 2]
|
||||
out = runForwards convLayer input :: S' ('D3 1 2 4)
|
||||
|
||||
grad = S3D' ( mkVector
|
||||
[(HStatic.matrix
|
||||
[ 1 , 0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2)
|
||||
,(HStatic.matrix
|
||||
[ 0 , 1 ] :: HStatic.L 1 2)] ) :: S' ('D3 1 2 4)
|
||||
[HStatic.matrix
|
||||
[ 1 , 0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ 0 , 0 ] :: HStatic.L 1 2
|
||||
,HStatic.matrix
|
||||
[ 0 , 1 ] :: HStatic.L 1 2] ) :: S' ('D3 1 2 4)
|
||||
|
||||
expectBack = (HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0
|
||||
@ -257,7 +254,7 @@ prop_single_conv_forwards = once $
|
||||
(S3D' out' , S3D' inX', Convolution' backGrad)
|
||||
-> ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
|
||||
.&&. ([HStatic.extract expectBack] === (HStatic.extract <$> vecToList inX'))
|
||||
.&&. ((HStatic.extract expectedGradient) === (HStatic.extract backGrad))
|
||||
.&&. (HStatic.extract expectedGradient === HStatic.extract backGrad)
|
||||
|
||||
return []
|
||||
tests :: IO Bool
|
||||
|
@ -4,7 +4,7 @@
|
||||
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
|
||||
module Test.Grenade.Layers.Pooling where
|
||||
|
||||
import Grenade.Layers.Pooling
|
||||
import Grenade.Layers.Internal.Pooling
|
||||
|
||||
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
|
||||
|
||||
@ -12,11 +12,11 @@ import Test.QuickCheck hiding ((><))
|
||||
|
||||
prop_pool = once $
|
||||
let input = (3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
[ 1.0, 14.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
expected = (2><3)
|
||||
[ 6.0, 7.0, 8.0
|
||||
[ 14.0, 14.0, 8.0
|
||||
, 10.0, 11.0, 12.0 ]
|
||||
out = poolForward 2 2 1 1 2 3 input
|
||||
in expected === out
|
||||
|
Loading…
Reference in New Issue
Block a user