diff --git a/src/Grenade/Layers/Internal/Convolution.hs b/src/Grenade/Layers/Internal/Convolution.hs index 336a9d3..7e5ad88 100644 --- a/src/Grenade/Layers/Internal/Convolution.hs +++ b/src/Grenade/Layers/Internal/Convolution.hs @@ -1,88 +1,23 @@ module Grenade.Layers.Internal.Convolution ( - im2col --- , im2colUnsafe - , vid2col - , col2im - , col2imFit - , col2vid - - , col2vidUnsafe + col2vidUnsafe , col2imUnsafe - , im2colUnsafe , vid2colUnsafe + , im2colUnsafe , fittingStarts ) where -import Control.Monad.ST -import Control.Parallel.Strategies ( parMap, rseq ) +import Control.Monad.ST ( runST ) -import Data.STRef +import Data.STRef ( newSTRef, modifySTRef, writeSTRef, readSTRef ) import Data.Foldable ( forM_ ) import Numeric.LinearAlgebra hiding ( uniformSample, konst ) -import qualified Numeric.LinearAlgebra as LA import qualified Numeric.LinearAlgebra.Devel as U -im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -im2col nrows ncols srows scols m = - let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols - in im2colFit starts nrows ncols m - -im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -im2colFit starts nrows ncols m = - let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts - in fromRows imRows - -vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double -vid2col nrows ncols srows scols inputrows inputcols ms = - let starts = fittingStarts inputrows nrows srows inputcols ncols scols - subs = parMap rseq (im2colFit starts nrows ncols) ms - in foldl1 (|||) subs - -col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double] -col2vid krows kcols srows scols drows dcols m = - let starts = fittingStart (cols m) (krows * kcols) (krows * kcols) - r = rows m - mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts - colSts = fittingStarts drows krows srows dcols kcols scols - in parMap rseq (col2imFit colSts krows kcols drows dcols) mats - -col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -col2im krows kcols srows scols drows dcols m = - let starts = fittingStarts drows krows srows dcols kcols scols - in col2imFit starts krows kcols drows dcols m - --- | These functions are not even remotely safe, but it's only called from the statically typed --- commands, so we should be good ?!?!? --- Returns the starting sub matrix locations which fit inside the larger matrix for the --- convolution. Takes into account the stride and kernel size. -fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)] -fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh = - let rs = fittingStart nrows kernelrows steprows - cs = fittingStart ncols kernelcols stepcolsh - ls = sequence [rs, cs] - in fmap (\[a,b] -> (a,b)) ls - --- | Returns the starting sub vector which fit inside the larger vector for the --- convolution. Takes into account the stride and kernel size. -fittingStart :: Int -> Int -> Int -> [Int] -fittingStart width kernel steps = - let go left | left + kernel < width - = left : go (left + steps) - | left + kernel == width - = [left] - | otherwise - = [] -- error "Kernel and step do not fit in matrix." - in go 0 - -col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -col2imFit starts krows kcols drows dcols m = - let indicies = (\[a,b] -> (a,b)) <$> sequence [[0..(krows-1)], [0..(kcols-1)]] - convs = fmap (zip indicies . toList) . toRows $ m - pairs = zip convs starts - accums = concatMap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs - in accum (LA.konst 0 (drows, dcols)) (+) accums - +-- This module provides provides im2col function and friends, ala caffe. +-- +-- /* From Caffe */ +-- @ -- void col2im_cpu(const Dtype* data_col, const int channels, -- const int height, const int width, const int kernel_h, const int kernel_w, -- const int pad_h, const int pad_w, @@ -118,14 +53,19 @@ col2imFit starts krows kcols drows dcols m = -- } -- } -- } +-- @ +-- +-- | col2im function. +-- +-- Takes a column patch, and reconstitutes it into a normal image. +-- Does not do any bounds checking on the matrix, so should only +-- be called once the sizes are ensured correct. col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do - let columnMatrixRows = rows columnMatrix dataIm <- U.newMatrix 0 destinationRows destinationCols - offsetR <- newSTRef 0 offsetC <- newSTRef 0 @@ -146,6 +86,11 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d return dataIm +-- | col2vid function. +-- +-- Takes a column patch image, and reconstitutes it into a normal image with multiple channels. +-- Does not do any bounds checking on the matrix, so should only +-- be called once the sizes are ensured correct. col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double] col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do let columnMatrixRows = rows columnMatrix @@ -226,3 +171,25 @@ im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatr modifySTRef inputRowRef (+1) return dataCol + + +-- | Returns the starting sub matrix locations which fit inside the larger matrix for the +-- convolution. Takes into account the stride and kernel size. +fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)] +fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh = + let rs = fittingStart nrows kernelrows steprows + cs = fittingStart ncols kernelcols stepcolsh + ls = sequence [rs, cs] + in fmap (\[a,b] -> (a,b)) ls + +-- | Returns the starting sub vector which fit inside the larger vector for the +-- convolution. Takes into account the stride and kernel size. +fittingStart :: Int -> Int -> Int -> [Int] +fittingStart width kernel steps = + let go left | left + kernel < width + = left : go (left + steps) + | left + kernel == width + = [left] + | otherwise + = [] + in go 0 diff --git a/test/Test/Grenade/Layers/Convolution.hs b/test/Test/Grenade/Layers/Convolution.hs index 7594d40..e073a82 100644 --- a/test/Test/Grenade/Layers/Convolution.hs +++ b/test/Test/Grenade/Layers/Convolution.hs @@ -27,7 +27,7 @@ prop_im2col_no_stride = once $ , 5.0, 6.0, 9.0, 10.0 , 6.0, 7.0, 10.0, 11.0 , 7.0, 8.0, 11.0, 12.0 ] - out = im2col 2 2 1 1 input + out = im2colUnsafe 2 2 1 1 input in expected === out prop_im2col_stride = once $ @@ -40,7 +40,7 @@ prop_im2col_stride = once $ , 3.0, 4.0, 7.0, 8.0 , 5.0, 6.0, 9.0, 10.0 , 7.0, 8.0, 11.0, 12.0 ] - out = im2col 2 2 1 2 input + out = im2colUnsafe 2 2 1 2 input in expected === out prop_im2col_other = once $ @@ -61,7 +61,7 @@ prop_im2col_sym_on_same_stride = once $ [ 1.0, 2.0, 3.0, 4.0 , 5.0, 6.0, 7.0, 8.0 , 9.0, 10.0, 11.0, 12.0 ] - out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input + out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input in input === out -- If there's no overlap (stride is the same size as the kernel) @@ -186,20 +186,7 @@ prop_vid2col_stride = once $ out = vid2colUnsafe 2 2 1 2 3 4 input in expected === out - prop_vid2col_invert = once $ - let input = [(3><4) - [ 1.0, 2.0, 3.0, 4.0 - , 5.0, 6.0, 7.0, 8.0 - , 9.0, 10.0, 11.0, 12.0 ] - , (3><4) - [ 21.0, 22.0, 23.0, 24.0 - , 25.0, 26.0, 27.0, 28.0 - , 29.0, 30.0, 31.0, 32.0 ] ] - out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input - in input === out - -prop_vid2col_invert_unsafe = once $ let input = [(3><4) [ 1.0, 2.0, 3.0, 4.0 , 5.0, 6.0, 7.0, 8.0