Remove old functions

This commit is contained in:
Huw Campbell 2016-12-08 10:03:29 +11:00
parent 670f2d952f
commit 2433e1fba2
2 changed files with 45 additions and 91 deletions

View File

@ -1,88 +1,23 @@
module Grenade.Layers.Internal.Convolution (
im2col
-- , im2colUnsafe
, vid2col
, col2im
, col2imFit
, col2vid
, col2vidUnsafe
col2vidUnsafe
, col2imUnsafe
, im2colUnsafe
, vid2colUnsafe
, im2colUnsafe
, fittingStarts
) where
import Control.Monad.ST
import Control.Parallel.Strategies ( parMap, rseq )
import Control.Monad.ST ( runST )
import Data.STRef
import Data.STRef ( newSTRef, modifySTRef, writeSTRef, readSTRef )
import Data.Foldable ( forM_ )
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra as LA
import qualified Numeric.LinearAlgebra.Devel as U
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2col nrows ncols srows scols m =
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
in im2colFit starts nrows ncols m
im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
im2colFit starts nrows ncols m =
let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts
in fromRows imRows
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2col nrows ncols srows scols inputrows inputcols ms =
let starts = fittingStarts inputrows nrows srows inputcols ncols scols
subs = parMap rseq (im2colFit starts nrows ncols) ms
in foldl1 (|||) subs
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vid krows kcols srows scols drows dcols m =
let starts = fittingStart (cols m) (krows * kcols) (krows * kcols)
r = rows m
mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts
colSts = fittingStarts drows krows srows dcols kcols scols
in parMap rseq (col2imFit colSts krows kcols drows dcols) mats
col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2im krows kcols srows scols drows dcols m =
let starts = fittingStarts drows krows srows dcols kcols scols
in col2imFit starts krows kcols drows dcols m
-- | These functions are not even remotely safe, but it's only called from the statically typed
-- commands, so we should be good ?!?!?
-- Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
let rs = fittingStart nrows kernelrows steprows
cs = fittingStart ncols kernelcols stepcolsh
ls = sequence [rs, cs]
in fmap (\[a,b] -> (a,b)) ls
-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
let go left | left + kernel < width
= left : go (left + steps)
| left + kernel == width
= [left]
| otherwise
= [] -- error "Kernel and step do not fit in matrix."
in go 0
col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imFit starts krows kcols drows dcols m =
let indicies = (\[a,b] -> (a,b)) <$> sequence [[0..(krows-1)], [0..(kcols-1)]]
convs = fmap (zip indicies . toList) . toRows $ m
pairs = zip convs starts
accums = concatMap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
in accum (LA.konst 0 (drows, dcols)) (+) accums
-- This module provides provides im2col function and friends, ala caffe.
--
-- /* From Caffe */
-- @
-- void col2im_cpu(const Dtype* data_col, const int channels,
-- const int height, const int width, const int kernel_h, const int kernel_w,
-- const int pad_h, const int pad_w,
@ -118,14 +53,19 @@ col2imFit starts krows kcols drows dcols m =
-- }
-- }
-- }
-- @
--
-- | col2im function.
--
-- Takes a column patch, and reconstitutes it into a normal image.
-- Does not do any bounds checking on the matrix, so should only
-- be called once the sizes are ensured correct.
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
let columnMatrixRows = rows columnMatrix
dataIm <- U.newMatrix 0 destinationRows destinationCols
offsetR <- newSTRef 0
offsetC <- newSTRef 0
@ -146,6 +86,11 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d
return dataIm
-- | col2vid function.
--
-- Takes a column patch image, and reconstitutes it into a normal image with multiple channels.
-- Does not do any bounds checking on the matrix, so should only
-- be called once the sizes are ensured correct.
col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do
let columnMatrixRows = rows columnMatrix
@ -226,3 +171,25 @@ im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatr
modifySTRef inputRowRef (+1)
return dataCol
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
let rs = fittingStart nrows kernelrows steprows
cs = fittingStart ncols kernelcols stepcolsh
ls = sequence [rs, cs]
in fmap (\[a,b] -> (a,b)) ls
-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
let go left | left + kernel < width
= left : go (left + steps)
| left + kernel == width
= [left]
| otherwise
= []
in go 0

View File

@ -27,7 +27,7 @@ prop_im2col_no_stride = once $
, 5.0, 6.0, 9.0, 10.0
, 6.0, 7.0, 10.0, 11.0
, 7.0, 8.0, 11.0, 12.0 ]
out = im2col 2 2 1 1 input
out = im2colUnsafe 2 2 1 1 input
in expected === out
prop_im2col_stride = once $
@ -40,7 +40,7 @@ prop_im2col_stride = once $
, 3.0, 4.0, 7.0, 8.0
, 5.0, 6.0, 9.0, 10.0
, 7.0, 8.0, 11.0, 12.0 ]
out = im2col 2 2 1 2 input
out = im2colUnsafe 2 2 1 2 input
in expected === out
prop_im2col_other = once $
@ -61,7 +61,7 @@ prop_im2col_sym_on_same_stride = once $
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
in input === out
-- If there's no overlap (stride is the same size as the kernel)
@ -186,20 +186,7 @@ prop_vid2col_stride = once $
out = vid2colUnsafe 2 2 1 2 3 4 input
in expected === out
prop_vid2col_invert = once $
let input = [(3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
, (3><4)
[ 21.0, 22.0, 23.0, 24.0
, 25.0, 26.0, 27.0, 28.0
, 29.0, 30.0, 31.0, 32.0 ] ]
out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
in input === out
prop_vid2col_invert_unsafe = once $
let input = [(3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0