Optimise poolForwardFit

This commit is contained in:
Huw Campbell 2016-12-09 20:51:28 +11:00
parent b914b91c1f
commit 267ebc2080

View File

@ -27,7 +27,7 @@ poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForwardFit starts nrows ncols _ outputCols m =
let els = fmap (\start -> unsafeMaxElement $ subMatrix start (nrows, ncols) m) starts
let els = fmap (\start -> unsafeMaxElementSubmatrix start (nrows, ncols) m) starts
in LA.matrix outputCols els
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
@ -42,16 +42,6 @@ poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
let starts = fittingStarts inRows krows srows inCols kcols scols
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
-- poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
-- poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
-- let inRows = rows inputMatrix
-- inCols = cols inputMatrix
-- inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
-- grads = toList $ flatten gradientMatrix
-- grads' = zip3 starts grads inds
-- accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
-- in accum (LA.konst 0 (inRows, inCols)) (+) accums
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $ do
let inRows = rows inputMatrix
@ -67,16 +57,8 @@ poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $
return retM
unsafeMaxElement :: Matrix Double -> Double
unsafeMaxElement m = uncurry (U.atM' m) $ unsafeMaxIndex m
unsafeMaxIndex :: Matrix Double -> (Int, Int)
unsafeMaxIndex m =
let mrows = [0 .. rows m - 1]
mcols = [0 .. cols m - 1]
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
unsafeMaxElementSubmatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> Double
unsafeMaxElementSubmatrix starts extent m = uncurry (U.atM' m) $ unsafeMaxIndexSubMatrix starts extent m
unsafeMaxIndexSubMatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> (Int, Int)
unsafeMaxIndexSubMatrix (startRow, startCol) (extentRow, extentCold) m =