mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Optimise poolForwardFit
This commit is contained in:
parent
b914b91c1f
commit
267ebc2080
@ -27,7 +27,7 @@ poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
|
||||
|
||||
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
poolForwardFit starts nrows ncols _ outputCols m =
|
||||
let els = fmap (\start -> unsafeMaxElement $ subMatrix start (nrows, ncols) m) starts
|
||||
let els = fmap (\start -> unsafeMaxElementSubmatrix start (nrows, ncols) m) starts
|
||||
in LA.matrix outputCols els
|
||||
|
||||
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
@ -42,16 +42,6 @@ poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
|
||||
let starts = fittingStarts inRows krows srows inCols kcols scols
|
||||
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
|
||||
|
||||
-- poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
-- poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
|
||||
-- let inRows = rows inputMatrix
|
||||
-- inCols = cols inputMatrix
|
||||
-- inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
||||
-- grads = toList $ flatten gradientMatrix
|
||||
-- grads' = zip3 starts grads inds
|
||||
-- accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
|
||||
-- in accum (LA.konst 0 (inRows, inCols)) (+) accums
|
||||
|
||||
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $ do
|
||||
let inRows = rows inputMatrix
|
||||
@ -67,16 +57,8 @@ poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $
|
||||
|
||||
return retM
|
||||
|
||||
unsafeMaxElement :: Matrix Double -> Double
|
||||
unsafeMaxElement m = uncurry (U.atM' m) $ unsafeMaxIndex m
|
||||
|
||||
unsafeMaxIndex :: Matrix Double -> (Int, Int)
|
||||
unsafeMaxIndex m =
|
||||
let mrows = [0 .. rows m - 1]
|
||||
mcols = [0 .. cols m - 1]
|
||||
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
|
||||
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
|
||||
|
||||
unsafeMaxElementSubmatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> Double
|
||||
unsafeMaxElementSubmatrix starts extent m = uncurry (U.atM' m) $ unsafeMaxIndexSubMatrix starts extent m
|
||||
|
||||
unsafeMaxIndexSubMatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> (Int, Int)
|
||||
unsafeMaxIndexSubMatrix (startRow, startCol) (extentRow, extentCold) m =
|
||||
|
Loading…
Reference in New Issue
Block a user