mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Remove old functions
This commit is contained in:
parent
670f2d952f
commit
2433e1fba2
@ -1,88 +1,23 @@
|
||||
module Grenade.Layers.Internal.Convolution (
|
||||
im2col
|
||||
-- , im2colUnsafe
|
||||
, vid2col
|
||||
, col2im
|
||||
, col2imFit
|
||||
, col2vid
|
||||
|
||||
, col2vidUnsafe
|
||||
col2vidUnsafe
|
||||
, col2imUnsafe
|
||||
, im2colUnsafe
|
||||
, vid2colUnsafe
|
||||
, im2colUnsafe
|
||||
, fittingStarts
|
||||
) where
|
||||
|
||||
import Control.Monad.ST
|
||||
import Control.Parallel.Strategies ( parMap, rseq )
|
||||
import Control.Monad.ST ( runST )
|
||||
|
||||
import Data.STRef
|
||||
import Data.STRef ( newSTRef, modifySTRef, writeSTRef, readSTRef )
|
||||
import Data.Foldable ( forM_ )
|
||||
|
||||
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||
import qualified Numeric.LinearAlgebra as LA
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
im2col nrows ncols srows scols m =
|
||||
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
|
||||
in im2colFit starts nrows ncols m
|
||||
|
||||
im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
im2colFit starts nrows ncols m =
|
||||
let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts
|
||||
in fromRows imRows
|
||||
|
||||
vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
|
||||
vid2col nrows ncols srows scols inputrows inputcols ms =
|
||||
let starts = fittingStarts inputrows nrows srows inputcols ncols scols
|
||||
subs = parMap rseq (im2colFit starts nrows ncols) ms
|
||||
in foldl1 (|||) subs
|
||||
|
||||
col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
|
||||
col2vid krows kcols srows scols drows dcols m =
|
||||
let starts = fittingStart (cols m) (krows * kcols) (krows * kcols)
|
||||
r = rows m
|
||||
mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts
|
||||
colSts = fittingStarts drows krows srows dcols kcols scols
|
||||
in parMap rseq (col2imFit colSts krows kcols drows dcols) mats
|
||||
|
||||
col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
col2im krows kcols srows scols drows dcols m =
|
||||
let starts = fittingStarts drows krows srows dcols kcols scols
|
||||
in col2imFit starts krows kcols drows dcols m
|
||||
|
||||
-- | These functions are not even remotely safe, but it's only called from the statically typed
|
||||
-- commands, so we should be good ?!?!?
|
||||
-- Returns the starting sub matrix locations which fit inside the larger matrix for the
|
||||
-- convolution. Takes into account the stride and kernel size.
|
||||
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
|
||||
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
|
||||
let rs = fittingStart nrows kernelrows steprows
|
||||
cs = fittingStart ncols kernelcols stepcolsh
|
||||
ls = sequence [rs, cs]
|
||||
in fmap (\[a,b] -> (a,b)) ls
|
||||
|
||||
-- | Returns the starting sub vector which fit inside the larger vector for the
|
||||
-- convolution. Takes into account the stride and kernel size.
|
||||
fittingStart :: Int -> Int -> Int -> [Int]
|
||||
fittingStart width kernel steps =
|
||||
let go left | left + kernel < width
|
||||
= left : go (left + steps)
|
||||
| left + kernel == width
|
||||
= [left]
|
||||
| otherwise
|
||||
= [] -- error "Kernel and step do not fit in matrix."
|
||||
in go 0
|
||||
|
||||
col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
col2imFit starts krows kcols drows dcols m =
|
||||
let indicies = (\[a,b] -> (a,b)) <$> sequence [[0..(krows-1)], [0..(kcols-1)]]
|
||||
convs = fmap (zip indicies . toList) . toRows $ m
|
||||
pairs = zip convs starts
|
||||
accums = concatMap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
|
||||
in accum (LA.konst 0 (drows, dcols)) (+) accums
|
||||
|
||||
-- This module provides provides im2col function and friends, ala caffe.
|
||||
--
|
||||
-- /* From Caffe */
|
||||
-- @
|
||||
-- void col2im_cpu(const Dtype* data_col, const int channels,
|
||||
-- const int height, const int width, const int kernel_h, const int kernel_w,
|
||||
-- const int pad_h, const int pad_w,
|
||||
@ -118,14 +53,19 @@ col2imFit starts krows kcols drows dcols m =
|
||||
-- }
|
||||
-- }
|
||||
-- }
|
||||
-- @
|
||||
--
|
||||
|
||||
-- | col2im function.
|
||||
--
|
||||
-- Takes a column patch, and reconstitutes it into a normal image.
|
||||
-- Does not do any bounds checking on the matrix, so should only
|
||||
-- be called once the sizes are ensured correct.
|
||||
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
|
||||
|
||||
let columnMatrixRows = rows columnMatrix
|
||||
|
||||
dataIm <- U.newMatrix 0 destinationRows destinationCols
|
||||
|
||||
offsetR <- newSTRef 0
|
||||
offsetC <- newSTRef 0
|
||||
|
||||
@ -146,6 +86,11 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d
|
||||
|
||||
return dataIm
|
||||
|
||||
-- | col2vid function.
|
||||
--
|
||||
-- Takes a column patch image, and reconstitutes it into a normal image with multiple channels.
|
||||
-- Does not do any bounds checking on the matrix, so should only
|
||||
-- be called once the sizes are ensured correct.
|
||||
col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
|
||||
col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do
|
||||
let columnMatrixRows = rows columnMatrix
|
||||
@ -226,3 +171,25 @@ im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatr
|
||||
modifySTRef inputRowRef (+1)
|
||||
|
||||
return dataCol
|
||||
|
||||
|
||||
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
|
||||
-- convolution. Takes into account the stride and kernel size.
|
||||
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
|
||||
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
|
||||
let rs = fittingStart nrows kernelrows steprows
|
||||
cs = fittingStart ncols kernelcols stepcolsh
|
||||
ls = sequence [rs, cs]
|
||||
in fmap (\[a,b] -> (a,b)) ls
|
||||
|
||||
-- | Returns the starting sub vector which fit inside the larger vector for the
|
||||
-- convolution. Takes into account the stride and kernel size.
|
||||
fittingStart :: Int -> Int -> Int -> [Int]
|
||||
fittingStart width kernel steps =
|
||||
let go left | left + kernel < width
|
||||
= left : go (left + steps)
|
||||
| left + kernel == width
|
||||
= [left]
|
||||
| otherwise
|
||||
= []
|
||||
in go 0
|
||||
|
@ -27,7 +27,7 @@ prop_im2col_no_stride = once $
|
||||
, 5.0, 6.0, 9.0, 10.0
|
||||
, 6.0, 7.0, 10.0, 11.0
|
||||
, 7.0, 8.0, 11.0, 12.0 ]
|
||||
out = im2col 2 2 1 1 input
|
||||
out = im2colUnsafe 2 2 1 1 input
|
||||
in expected === out
|
||||
|
||||
prop_im2col_stride = once $
|
||||
@ -40,7 +40,7 @@ prop_im2col_stride = once $
|
||||
, 3.0, 4.0, 7.0, 8.0
|
||||
, 5.0, 6.0, 9.0, 10.0
|
||||
, 7.0, 8.0, 11.0, 12.0 ]
|
||||
out = im2col 2 2 1 2 input
|
||||
out = im2colUnsafe 2 2 1 2 input
|
||||
in expected === out
|
||||
|
||||
prop_im2col_other = once $
|
||||
@ -61,7 +61,7 @@ prop_im2col_sym_on_same_stride = once $
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input
|
||||
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
|
||||
in input === out
|
||||
|
||||
-- If there's no overlap (stride is the same size as the kernel)
|
||||
@ -186,20 +186,7 @@ prop_vid2col_stride = once $
|
||||
out = vid2colUnsafe 2 2 1 2 3 4 input
|
||||
in expected === out
|
||||
|
||||
|
||||
prop_vid2col_invert = once $
|
||||
let input = [(3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
, 9.0, 10.0, 11.0, 12.0 ]
|
||||
, (3><4)
|
||||
[ 21.0, 22.0, 23.0, 24.0
|
||||
, 25.0, 26.0, 27.0, 28.0
|
||||
, 29.0, 30.0, 31.0, 32.0 ] ]
|
||||
out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
|
||||
in input === out
|
||||
|
||||
prop_vid2col_invert_unsafe = once $
|
||||
let input = [(3><4)
|
||||
[ 1.0, 2.0, 3.0, 4.0
|
||||
, 5.0, 6.0, 7.0, 8.0
|
||||
|
Loading…
Reference in New Issue
Block a user