Remove old functions

This commit is contained in:
Huw Campbell 2016-12-08 10:03:29 +11:00
parent 670f2d952f
commit 2433e1fba2
2 changed files with 45 additions and 91 deletions

View File

@ -1,88 +1,23 @@
module Grenade.Layers.Internal.Convolution ( module Grenade.Layers.Internal.Convolution (
im2col col2vidUnsafe
-- , im2colUnsafe
, vid2col
, col2im
, col2imFit
, col2vid
, col2vidUnsafe
, col2imUnsafe , col2imUnsafe
, im2colUnsafe
, vid2colUnsafe , vid2colUnsafe
, im2colUnsafe
, fittingStarts , fittingStarts
) where ) where
import Control.Monad.ST import Control.Monad.ST ( runST )
import Control.Parallel.Strategies ( parMap, rseq )
import Data.STRef import Data.STRef ( newSTRef, modifySTRef, writeSTRef, readSTRef )
import Data.Foldable ( forM_ ) import Data.Foldable ( forM_ )
import Numeric.LinearAlgebra hiding ( uniformSample, konst ) import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra as LA
import qualified Numeric.LinearAlgebra.Devel as U import qualified Numeric.LinearAlgebra.Devel as U
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -- This module provides provides im2col function and friends, ala caffe.
im2col nrows ncols srows scols m = --
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols -- /* From Caffe */
in im2colFit starts nrows ncols m -- @
im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
im2colFit starts nrows ncols m =
let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts
in fromRows imRows
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2col nrows ncols srows scols inputrows inputcols ms =
let starts = fittingStarts inputrows nrows srows inputcols ncols scols
subs = parMap rseq (im2colFit starts nrows ncols) ms
in foldl1 (|||) subs
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vid krows kcols srows scols drows dcols m =
let starts = fittingStart (cols m) (krows * kcols) (krows * kcols)
r = rows m
mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts
colSts = fittingStarts drows krows srows dcols kcols scols
in parMap rseq (col2imFit colSts krows kcols drows dcols) mats
col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2im krows kcols srows scols drows dcols m =
let starts = fittingStarts drows krows srows dcols kcols scols
in col2imFit starts krows kcols drows dcols m
-- | These functions are not even remotely safe, but it's only called from the statically typed
-- commands, so we should be good ?!?!?
-- Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
let rs = fittingStart nrows kernelrows steprows
cs = fittingStart ncols kernelcols stepcolsh
ls = sequence [rs, cs]
in fmap (\[a,b] -> (a,b)) ls
-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
let go left | left + kernel < width
= left : go (left + steps)
| left + kernel == width
= [left]
| otherwise
= [] -- error "Kernel and step do not fit in matrix."
in go 0
col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imFit starts krows kcols drows dcols m =
let indicies = (\[a,b] -> (a,b)) <$> sequence [[0..(krows-1)], [0..(kcols-1)]]
convs = fmap (zip indicies . toList) . toRows $ m
pairs = zip convs starts
accums = concatMap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
in accum (LA.konst 0 (drows, dcols)) (+) accums
-- void col2im_cpu(const Dtype* data_col, const int channels, -- void col2im_cpu(const Dtype* data_col, const int channels,
-- const int height, const int width, const int kernel_h, const int kernel_w, -- const int height, const int width, const int kernel_h, const int kernel_w,
-- const int pad_h, const int pad_w, -- const int pad_h, const int pad_w,
@ -118,14 +53,19 @@ col2imFit starts krows kcols drows dcols m =
-- } -- }
-- } -- }
-- } -- }
-- @
--
-- | col2im function.
--
-- Takes a column patch, and reconstitutes it into a normal image.
-- Does not do any bounds checking on the matrix, so should only
-- be called once the sizes are ensured correct.
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
let columnMatrixRows = rows columnMatrix let columnMatrixRows = rows columnMatrix
dataIm <- U.newMatrix 0 destinationRows destinationCols dataIm <- U.newMatrix 0 destinationRows destinationCols
offsetR <- newSTRef 0 offsetR <- newSTRef 0
offsetC <- newSTRef 0 offsetC <- newSTRef 0
@ -146,6 +86,11 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d
return dataIm return dataIm
-- | col2vid function.
--
-- Takes a column patch image, and reconstitutes it into a normal image with multiple channels.
-- Does not do any bounds checking on the matrix, so should only
-- be called once the sizes are ensured correct.
col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double] col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do
let columnMatrixRows = rows columnMatrix let columnMatrixRows = rows columnMatrix
@ -226,3 +171,25 @@ im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatr
modifySTRef inputRowRef (+1) modifySTRef inputRowRef (+1)
return dataCol return dataCol
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
let rs = fittingStart nrows kernelrows steprows
cs = fittingStart ncols kernelcols stepcolsh
ls = sequence [rs, cs]
in fmap (\[a,b] -> (a,b)) ls
-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
let go left | left + kernel < width
= left : go (left + steps)
| left + kernel == width
= [left]
| otherwise
= []
in go 0

View File

@ -27,7 +27,7 @@ prop_im2col_no_stride = once $
, 5.0, 6.0, 9.0, 10.0 , 5.0, 6.0, 9.0, 10.0
, 6.0, 7.0, 10.0, 11.0 , 6.0, 7.0, 10.0, 11.0
, 7.0, 8.0, 11.0, 12.0 ] , 7.0, 8.0, 11.0, 12.0 ]
out = im2col 2 2 1 1 input out = im2colUnsafe 2 2 1 1 input
in expected === out in expected === out
prop_im2col_stride = once $ prop_im2col_stride = once $
@ -40,7 +40,7 @@ prop_im2col_stride = once $
, 3.0, 4.0, 7.0, 8.0 , 3.0, 4.0, 7.0, 8.0
, 5.0, 6.0, 9.0, 10.0 , 5.0, 6.0, 9.0, 10.0
, 7.0, 8.0, 11.0, 12.0 ] , 7.0, 8.0, 11.0, 12.0 ]
out = im2col 2 2 1 2 input out = im2colUnsafe 2 2 1 2 input
in expected === out in expected === out
prop_im2col_other = once $ prop_im2col_other = once $
@ -61,7 +61,7 @@ prop_im2col_sym_on_same_stride = once $
[ 1.0, 2.0, 3.0, 4.0 [ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0 , 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ] , 9.0, 10.0, 11.0, 12.0 ]
out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
in input === out in input === out
-- If there's no overlap (stride is the same size as the kernel) -- If there's no overlap (stride is the same size as the kernel)
@ -186,20 +186,7 @@ prop_vid2col_stride = once $
out = vid2colUnsafe 2 2 1 2 3 4 input out = vid2colUnsafe 2 2 1 2 3 4 input
in expected === out in expected === out
prop_vid2col_invert = once $ prop_vid2col_invert = once $
let input = [(3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
, (3><4)
[ 21.0, 22.0, 23.0, 24.0
, 25.0, 26.0, 27.0, 28.0
, 29.0, 30.0, 31.0, 32.0 ] ]
out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
in input === out
prop_vid2col_invert_unsafe = once $
let input = [(3><4) let input = [(3><4)
[ 1.0, 2.0, 3.0, 4.0 [ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0 , 5.0, 6.0, 7.0, 8.0