Add stupidly fast im2col in c

This commit is contained in:
Huw Campbell 2016-12-12 23:35:00 +11:00
parent f04bfebfee
commit 3830e55a7c
7 changed files with 159 additions and 31 deletions

View File

@ -11,6 +11,10 @@ main = defaultMain [
, bench "im2col 28x28" $ whnf (im2colUnsafe 5 5 1 1) ((28><28) [1..])
, bench "im2col 100x100" $ whnf (im2colUnsafe 10 10 1 1) ((100><100) [1..])
]
, bgroup "im2col_c" [ bench "im2col_c 3x4" $ whnf (im2col_c 2 2 1 1) ((3><4) [1..])
, bench "im2col_c 28x28" $ whnf (im2col_c 5 5 1 1) ((28><28) [1..])
, bench "im2col_c 100x100" $ whnf (im2col_c 10 10 1 1) ((100><100) [1..])
]
, bgroup "col2im" [ bench "col2im 3x4" $ whnf (col2imUnsafe 2 2 1 1 3 4) ((6><4) [1..])
, bench "col2im 28x28" $ whnf (col2imUnsafe 5 5 1 1 28 28) ((576><25) [1..])
, bench "col2im 100x100" $ whnf (col2imUnsafe 10 10 1 1 100 100) ((8281><100) [1..])

30
cbits/im2col.c Normal file
View File

@ -0,0 +1,30 @@
#include "im2col.h"
inline int is_a_ge_zero_and_a_lt_b(int a, int b) {
return a >= 0 && a < b;
}
void im2col_cpu(const double* data_im, int dataOffset, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
double* data_col) {
data_im += dataOffset;
const int output_h = (height - kernel_h) / stride_h + 1;
const int output_w = (width - kernel_w) / stride_w + 1;
const int channel_size = height * width;
for (int channel = channels; channel--; data_im += channel_size) {
for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
for (int fitting_width = 0; fitting_width <= (width - kernel_w); fitting_width += stride_w) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = fitting_height + kernel_row;
int input_col = fitting_width + kernel_col;
*(data_col++) = data_im[input_row * width + input_col];
}
}
}
}
}
}

7
cbits/im2col.h Normal file
View File

@ -0,0 +1,7 @@
#include <stdio.h>
#include <stdint.h>
void im2col_cpu(const double* data_im, int dataOffset, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w,
double* data_col);

View File

@ -10,6 +10,10 @@ cabal-version: >= 1.8
build-type: Simple
description: grenade.
extra-source-files:
cbits/im2col.h
cbits/im2col.c
library
build-depends:
base >= 4.8 && < 5
@ -53,6 +57,8 @@ library
Grenade.Layers.Internal.Convolution
Grenade.Layers.Internal.Pooling
includes: cbits/im2col.h
c-sources: cbits/im2col.c
executable feedforward
ghc-options: -Wall -threaded -O2

View File

@ -155,7 +155,7 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = im2colUnsafe kx ky sx sy ex
c = im2col_c kx ky sx sy ex
mt = c LA.<> ek
r = col2vidUnsafe 1 1 1 1 ox oy mt
rs = fmap (fromJust . create) r
@ -172,7 +172,7 @@ instance ( KnownNat kernelRows
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = im2colUnsafe kx ky sx sy ex
c = im2col_c kx ky sx sy ex
eo = vecToList $ fmap extract dEdy
ek = extract kernel

View File

@ -1,8 +1,11 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Grenade.Layers.Internal.Convolution (
col2vidUnsafe
, col2imUnsafe
, vid2colUnsafe
, im2colUnsafe
, im2col_c
, fittingStarts
, unsafeModifyMatrix
) where
@ -13,11 +16,15 @@ import Data.STRef ( newSTRef, modifySTRef', writeSTRef, readSTRef )
import Data.Foldable ( forM_ )
import Data.Traversable ( forM )
import Foreign ( mallocForeignPtrArray0, withForeignPtr )
import Foreign.Ptr ( Ptr )
import Foreign.Storable( Storable )
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra.Devel as U
import System.IO.Unsafe ( unsafePerformIO )
-- This module provides provides im2col function and friends, ala caffe.
--
-- /* From Caffe */
@ -105,13 +112,13 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
offsetR <- newSTRef 0
offsetC <- newSTRef 0
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
inputColumn <- newSTRef 0
offsetR' <- readSTRef offsetR
offsetC' <- readSTRef offsetC
inputColumn <- newSTRef offsetM
offsetR' <- readSTRef offsetR
offsetC' <- readSTRef offsetC
forM_ [offsetR' .. offsetR' + kernelRows -1] $ \kr ->
forM_ [offsetC' .. offsetC' + kernelColumns -1] $ \kc -> do
ic <- readSTRef inputColumn
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix ir (ic + offsetM))
unsafeModifyMatrix dataIm kr kc (+ U.atM' columnMatrix ir ic)
modifySTRef' inputColumn (+1)
if offsetC' + kernelColumns < destinationCols
@ -120,6 +127,62 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
U.unsafeFreezeMatrix dataIm
im2col_c :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2col_c kernelRows kernelColumns strideRows strideColumns dataIm =
let height = rows dataIm
width = cols dataIm
vec = flatten dataIm
rowOut = (height - kernelRows) `div` strideRows + 1
colOut = (width - kernelColumns) `div` strideColumns + 1
kernelSize = kernelRows * kernelColumns
numberOfPatches = rowOut * colOut
in unsafePerformIO $ do
outPtr <- mallocForeignPtrArray0 (numberOfPatches * kernelSize)
let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
withForeignPtr inPtr $ \inPtr' ->
withForeignPtr outPtr $ \outPtr' ->
im2col_cpu inPtr' inOffset 1 height width kernelRows kernelColumns strideRows strideColumns outPtr'
let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * kernelSize)
return $ U.matrixFromVector U.RowMajor numberOfPatches kernelSize matVec
foreign import ccall safe
im2col_cpu
:: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
unsafeModifyMatrix :: (Storable t) => U.STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
unsafeModifyMatrix x r c f = U.unsafeReadMatrix x r c >>= U.unsafeWriteMatrix x r c . f
{-# INLINE unsafeModifyMatrix #-}
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
let rs = fittingStart nrows kernelrows steprows
cs = fittingStart ncols kernelcols stepcolsh
in concatMap ( \r -> fmap (\c -> (r , c)) cs ) rs
-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
let go left | left + kernel < width
= left : go (left + steps)
| left + kernel == width
= [left]
| otherwise
= []
in go 0
-- | Old functions (useful for sanity checking and benchmarking)
vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
@ -168,28 +231,3 @@ im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatr
modifySTRef' inputRowRef (+1)
return dataCol
unsafeModifyMatrix :: (Storable t) => U.STMatrix s t -> Int -> Int -> (t -> t) -> ST s ()
unsafeModifyMatrix x r c f = U.unsafeReadMatrix x r c >>= U.unsafeWriteMatrix x r c . f
{-# INLINE unsafeModifyMatrix #-}
-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
let rs = fittingStart nrows kernelrows steprows
cs = fittingStart ncols kernelcols stepcolsh
in concatMap ( \r -> fmap (\c -> (r , c)) cs ) rs
-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
let go left | left + kernel < width
= left : go (left + steps)
| left + kernel == width
= [left]
| otherwise
= []
in go 0

View File

@ -30,6 +30,21 @@ prop_im2col_no_stride = once $
out = im2colUnsafe 2 2 1 1 input
in expected === out
prop_im2col_c = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
expected = (6><4)
[ 1.0, 2.0, 5.0, 6.0
, 2.0, 3.0, 6.0, 7.0
, 3.0, 4.0, 7.0, 8.0
, 5.0, 6.0, 9.0, 10.0
, 6.0, 7.0, 10.0, 11.0
, 7.0, 8.0, 11.0, 12.0 ]
out = im2col_c 2 2 1 1 input
in expected === out
prop_im2col_stride = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0
@ -43,6 +58,19 @@ prop_im2col_stride = once $
out = im2colUnsafe 2 2 1 2 input
in expected === out
prop_im2col_c_stride = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
expected = (4><4)
[ 1.0, 2.0, 5.0, 6.0
, 3.0, 4.0, 7.0, 8.0
, 5.0, 6.0, 9.0, 10.0
, 7.0, 8.0, 11.0, 12.0 ]
out = im2col_c 2 2 1 2 input
in expected === out
prop_im2col_other = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0
@ -54,6 +82,21 @@ prop_im2col_other = once $
out = im2colUnsafe 3 2 1 2 input
in expected === out
prop_im2col_c_other = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
expected = (2><6)
[ 1.0, 2.0, 5.0, 6.0 , 9.0, 10.0
, 3.0, 4.0, 7.0, 8.0 , 11.0 ,12.0 ]
out = im2col_c 3 2 1 2 input
in expected === out
prop_im2col_bigger = once $
let input = (7><7) [ 1.0 .. ]
in im2colUnsafe 5 5 2 2 input === im2col_c 5 5 2 2 input
-- If there's no overlap (stride is the same size as the kernel)
-- then col2im . im2col should be symmetric.
prop_im2col_sym_on_same_stride = once $