Make Pad 3D faster.

Using dependent types in the deeper functions and
requiring a Proxy to reach them meant we required
dictionary passing to get the Nats. This made the
pad and crop layers almost 1000 times slower than
they should have been.
This commit is contained in:
Huw Campbell 2017-02-03 21:04:27 +11:00
parent dbeb962ae6
commit 1e461cb07a
13 changed files with 167 additions and 187 deletions

View File

@ -1,26 +1,55 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Criterion.Main
import Grenade
import Grenade.Layers.Internal.Convolution
import Grenade.Layers.Internal.Pooling
import Numeric.LinearAlgebra
main :: IO ()
main = defaultMain [
bgroup "im2col" [ bench "im2col 3x4" $ whnf (im2col 2 2 1 1) ((3><4) [1..])
, bench "im2col 28x28" $ whnf (im2col 5 5 1 1) ((28><28) [1..])
, bench "im2col 100x100" $ whnf (im2col 10 10 1 1) ((100><100) [1..])
]
, bgroup "col2im" [ bench "col2im 3x4" $ whnf (col2im 2 2 1 1 3 4) ((6><4) [1..])
, bench "col2im 28x28" $ whnf (col2im 5 5 1 1 28 28) ((576><25) [1..])
, bench "col2im 100x100" $ whnf (col2im 10 10 1 1 100 100) ((8281><100) [1..])
]
, bgroup "poolfw" [ bench "poolforwards 3x4" $ whnf (poolForward 1 3 4 2 2 1 1) ((3><4) [1..])
, bench "poolforwards 28x28" $ whnf (poolForward 1 28 28 5 5 1 1) ((28><28) [1..])
, bench "poolforwards 100x100" $ whnf (poolForward 1 100 100 10 10 1 1) ((100><100) [1..])
]
, bgroup "poolbw" [ bench "poolbackwards 3x4" $ whnf (poolBackward 1 3 4 2 2 1 1 ((3><4) [1..])) ((2><3) [1..])
, bench "poolbackwards 28x28" $ whnf (poolBackward 1 28 28 5 5 1 1 ((28><28) [1..])) ((24><24) [1..])
, bench "poolbackwards 100x100" $ whnf (poolBackward 1 100 100 10 10 1 1 ((100><100) [1..])) ((91><91) [1..])
]
]
main = do
x :: S ('D2 60 60 ) <- randomOfShape
y :: S ('D3 60 60 1) <- randomOfShape
defaultMain [
bgroup "im2col" [ bench "im2col 3x4" $ whnf (im2col 2 2 1 1) ((3><4) [1..])
, bench "im2col 28x28" $ whnf (im2col 5 5 1 1) ((28><28) [1..])
, bench "im2col 100x100" $ whnf (im2col 10 10 1 1) ((100><100) [1..])
]
, bgroup "col2im" [ bench "col2im 3x4" $ whnf (col2im 2 2 1 1 3 4) ((6><4) [1..])
, bench "col2im 28x28" $ whnf (col2im 5 5 1 1 28 28) ((576><25) [1..])
, bench "col2im 100x100" $ whnf (col2im 10 10 1 1 100 100) ((8281><100) [1..])
]
, bgroup "poolfw" [ bench "poolforwards 3x4" $ whnf (poolForward 1 3 4 2 2 1 1) ((3><4) [1..])
, bench "poolforwards 28x28" $ whnf (poolForward 1 28 28 5 5 1 1) ((28><28) [1..])
, bench "poolforwards 100x100" $ whnf (poolForward 1 100 100 10 10 1 1) ((100><100) [1..])
]
, bgroup "poolbw" [ bench "poolbackwards 3x4" $ whnf (poolBackward 1 3 4 2 2 1 1 ((3><4) [1..])) ((2><3) [1..])
, bench "poolbackwards 28x28" $ whnf (poolBackward 1 28 28 5 5 1 1 ((28><28) [1..])) ((24><24) [1..])
, bench "poolbackwards 100x100" $ whnf (poolBackward 1 100 100 10 10 1 1 ((100><100) [1..])) ((91><91) [1..])
]
, bgroup "padcrop" [ bench "pad 2D 60x60" $ whnf (testRun2D Pad) x
, bench "pad 3D 60x60" $ whnf (testRun3D Pad) y
, bench "crop 2D 60x60" $ whnf (testRun2D' Crop) x
, bench "crop 3D 60x60" $ whnf (testRun3D' Crop) y
]
]
testRun2D :: Pad 1 1 1 1 -> S ('D2 60 60) -> S ('D2 62 62)
testRun2D = snd ... runForwards
testRun3D :: Pad 1 1 1 1 -> S ('D3 60 60 1) -> S ('D3 62 62 1)
testRun3D = snd ... runForwards
testRun2D' :: Crop 1 1 1 1 -> S ('D2 60 60) -> S ('D2 58 58)
testRun2D' = snd ... runForwards
testRun3D' :: Crop 1 1 1 1 -> S ('D3 60 60 1) -> S ('D3 58 58 1)
testRun3D' = snd ... runForwards
(...) :: (a -> b) -> (c -> d -> a) -> c -> d -> b
(...) = (.) . (.)

View File

@ -1,23 +1,21 @@
#include "pad.h"
void pad_cpu(const double* data, const int channels,
void pad_cpu(double* data, const int channels,
const int height, const int width, const int pad_left, const int pad_top,
const int pad_right, const int pad_bottom,
double* data_padded) {
const int pad_width = width + pad_left + pad_right;
const int pad_height = height + pad_top + pad_bottom;
const int channel_size = height * width;
memset(data_padded, 0, pad_height * pad_width * channels * sizeof(double));
for (int channel = 0; channel < channels; channel++) {
double* px = data_padded + (pad_width * pad_top + pad_left) + channel * (pad_width * pad_height);
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
*(px++) = data[y * width + x + channel_size * channel];
}
px += pad_left + pad_right;
memcpy(px, data, sizeof(double) * width);
px += pad_width;
data += width;
}
}
}
@ -30,15 +28,12 @@ void crop_cpu(double* data, const int channels,
const int crop_width = width + crop_left + crop_right;
const int crop_height = height + crop_top + crop_bottom;
const int channel_size = height * width;
for (int channel = 0; channel < channels; channel++) {
double* px = data + (crop_width * crop_top + crop_left) + channel * (crop_width * crop_height);
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
data_cropped[y * width + x + channel_size * channel] = *(px++);
}
px += crop_left + crop_right;
memcpy(data_cropped, px, sizeof(double) * width);
px += crop_width;
data_cropped += width;
}
}
}

View File

@ -2,7 +2,7 @@
#include <stdint.h>
#include <string.h>
void pad_cpu(const double* data_im, const int channels,
void pad_cpu(double* data_im, const int channels,
const int height, const int width, const int pad_left, const int pad_top,
const int pad_right, const int pad_bottom,
double* data_col);

View File

@ -22,22 +22,19 @@ library
build-depends:
base >= 4.8 && < 5
, bytestring == 0.10.*
, async
, containers
, deepseq
, either == 4.4.*
, cereal
, exceptions == 0.8.*
, hmatrix
, hmatrix == 0.18.*
, MonadRandom
, mtl >= 2.2.1 && < 2.3
, parallel == 3.2.*
, primitive
, reflection
, text == 1.2.*
, transformers
, singletons
, vector
, singletons >= 2.1 && < 2.3
, vector == 0.11.*
ghc-options:
-Wall
@ -100,7 +97,7 @@ executable feedforward
, bytestring
, cereal
, either
, optparse-applicative == 0.12.*
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix
@ -115,7 +112,7 @@ executable mnist
, grenade
, attoparsec
, either
, optparse-applicative == 0.12.*
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
@ -131,7 +128,7 @@ executable recurrent
, grenade
, attoparsec
, either
, optparse-applicative == 0.12.*
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
@ -149,7 +146,7 @@ executable shakespeare
, bytestring
, cereal
, either
, optparse-applicative == 0.12.*
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19

View File

@ -10,6 +10,7 @@ import Data.List ( foldl' )
import qualified Data.ByteString as B
import Data.Serialize
import Data.Semigroup ( (<>) )
import GHC.TypeLits

View File

@ -12,6 +12,7 @@ import Control.Monad.Trans.Except
import qualified Data.Attoparsec.Text as A
import Data.List ( foldl' )
import Data.Semigroup ( (<>) )
import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified Data.Vector.Storable as V
@ -24,7 +25,6 @@ import Options.Applicative
import Grenade
import Grenade.Utils.OneHot
-- The definition of our convolutional neural network.
-- In the type signature, we have a type level list of shapes which are passed between the layers.
-- One can see that the images we are inputing are two dimensional with 28 * 28 pixels.

View File

@ -14,6 +14,7 @@ import Data.List ( unfoldr )
#else
import Data.List ( cycle, unfoldr )
#endif
import Data.Semigroup ( (<>) )
import qualified Numeric.LinearAlgebra.Static as SA

View File

@ -13,6 +13,7 @@ import Control.Monad.Trans.Except
import Data.Char ( isUpper, toUpper, toLower )
import Data.List ( foldl' )
import Data.Maybe ( fromMaybe )
import Data.Semigroup ( (<>) )
import qualified Data.Vector as V
import Data.Vector ( Vector )

View File

@ -18,7 +18,7 @@ module Grenade.Core.Layer (
, UpdateLayer (..)
) where
import Control.Monad.Random (MonadRandom)
import Control.Monad.Random ( MonadRandom )
import Data.List ( foldl' )

View File

@ -82,18 +82,30 @@ instance ( KnownNat cropLeft
, (outputColumns + cropLeft + cropRight) ~ inputColumns
) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
type Tape (Crop cropLeft cropTop cropRight cropBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = ()
runForwards Crop input =
let cropl = Proxy :: Proxy cropLeft
cropt = Proxy :: Proxy cropTop
cropr = Proxy :: Proxy cropRight
cropb = Proxy :: Proxy cropBottom
cropped = crop cropl cropt cropr cropb input
in ((), cropped)
runForwards Crop (S3D input) =
let padl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
padt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
padr = fromIntegral $ natVal (Proxy :: Proxy cropRight)
padb = fromIntegral $ natVal (Proxy :: Proxy cropBottom)
inr = fromIntegral $ natVal (Proxy :: Proxy inputRows)
inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
outr = fromIntegral $ natVal (Proxy :: Proxy outputRows)
outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
m = extract input
cropped = crop ch padl padt padr padb outr outc inr inc m
in ((), S3D . fromJust . create $ cropped)
runBackwards Crop () gradient =
let cropl = Proxy :: Proxy cropLeft
cropt = Proxy :: Proxy cropTop
cropr = Proxy :: Proxy cropRight
cropb = Proxy :: Proxy cropBottom
padded = pad cropl cropt cropr cropb gradient
in ((), padded)
runBackwards Crop () (S3D gradient) =
let padl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
padt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
padr = fromIntegral $ natVal (Proxy :: Proxy cropRight)
padb = fromIntegral $ natVal (Proxy :: Proxy cropBottom)
inr = fromIntegral $ natVal (Proxy :: Proxy inputRows)
inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
outr = fromIntegral $ natVal (Proxy :: Proxy outputRows)
outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
m = extract gradient
padded = pad ch padl padt padr padb outr outc inr inc m
in ((), S3D . fromJust . create $ padded)

View File

@ -1,123 +1,53 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
#if __GLASGOW_HASKELL__ == 800
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
#endif
module Grenade.Layers.Internal.Pad (
pad
, crop
) where
import Data.Maybe ( fromJust )
import Data.Proxy
import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
import GHC.TypeLits
import Grenade.Core
import Foreign ( mallocForeignPtrArray, withForeignPtr )
import Foreign.Ptr ( Ptr )
import Numeric.LinearAlgebra ( flatten )
import Numeric.LinearAlgebra.Static ( extract )
import Numeric.LinearAlgebra ( flatten, Matrix )
import qualified Numeric.LinearAlgebra.Devel as U
import System.IO.Unsafe ( unsafePerformIO )
pad :: forall padLeft padTop padRight padBottom rows rows' cols cols' channels.
( KnownNat padLeft
, KnownNat padTop
, KnownNat padRight
, KnownNat padBottom
, KnownNat rows
, KnownNat rows'
, KnownNat cols
, KnownNat cols'
, KnownNat channels
, rows' ~ (rows + padTop + padBottom)
, cols' ~ (cols + padLeft + padRight)
, KnownNat (rows' * channels)
) => Proxy padLeft
-> Proxy padTop
-> Proxy padRight
-> Proxy padBottom
-> S ('D3 rows cols channels)
-> S ('D3 rows' cols' channels)
pad _ _ _ _ (S3D m) =
let channels = fromIntegral $ natVal (Proxy :: Proxy channels)
padLeft = fromIntegral $ natVal (Proxy :: Proxy padLeft)
padTop = fromIntegral $ natVal (Proxy :: Proxy padTop)
padRight = fromIntegral $ natVal (Proxy :: Proxy padRight)
padBottom = fromIntegral $ natVal (Proxy :: Proxy padBottom)
rows = fromIntegral $ natVal (Proxy :: Proxy rows)
cols = fromIntegral $ natVal (Proxy :: Proxy cols)
rows' = fromIntegral $ natVal (Proxy :: Proxy rows')
cols' = fromIntegral $ natVal (Proxy :: Proxy cols')
outMatSize = rows' * cols' * channels
pad :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
pad channels padLeft padTop padRight padBottom rows cols rows' cols' m
= let outMatSize = rows' * cols' * channels
vec = flatten m
in unsafePerformIO $ do
outPtr <- mallocForeignPtrArray outMatSize
let (inPtr, _) = U.unsafeToForeignPtr0 vec
vec = flatten (extract m)
in unsafePerformIO $ do
outPtr <- mallocForeignPtrArray outMatSize
let (inPtr, _) = U.unsafeToForeignPtr0 vec
withForeignPtr inPtr $ \inPtr' ->
withForeignPtr outPtr $ \outPtr' ->
pad_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr'
withForeignPtr inPtr $ \inPtr' ->
withForeignPtr outPtr $ \outPtr' ->
pad_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr'
let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize
return (fromJust $ fromStorable matVec)
let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize
return (U.matrixFromVector U.RowMajor (rows' * channels) cols' matVec)
{-# INLINE pad #-}
foreign import ccall unsafe
pad_cpu
:: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
crop :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
crop channels padLeft padTop padRight padBottom rows cols _ _ m
= let outMatSize = rows * cols * channels
vec = flatten m
in unsafePerformIO $ do
outPtr <- mallocForeignPtrArray outMatSize
let (inPtr, _) = U.unsafeToForeignPtr0 vec
crop :: forall padLeft padTop padRight padBottom rows rows' cols cols' channels.
( KnownNat padLeft
, KnownNat padTop
, KnownNat padRight
, KnownNat padBottom
, KnownNat rows
, KnownNat cols
, KnownNat cols'
, KnownNat channels
, rows' ~ (rows + padTop + padBottom)
, cols' ~ (cols + padLeft + padRight)
, KnownNat (rows * channels)
) => Proxy padLeft
-> Proxy padTop
-> Proxy padRight
-> Proxy padBottom
-> S ('D3 rows' cols' channels)
-> S ('D3 rows cols channels)
crop _ _ _ _ (S3D m) =
let channels = fromIntegral $ natVal (Proxy :: Proxy channels)
padLeft = fromIntegral $ natVal (Proxy :: Proxy padLeft)
padTop = fromIntegral $ natVal (Proxy :: Proxy padTop)
padRight = fromIntegral $ natVal (Proxy :: Proxy padRight)
padBottom = fromIntegral $ natVal (Proxy :: Proxy padBottom)
rows = fromIntegral $ natVal (Proxy :: Proxy rows)
cols = fromIntegral $ natVal (Proxy :: Proxy cols)
outMatSize = rows * cols * channels
withForeignPtr inPtr $ \inPtr' ->
withForeignPtr outPtr $ \outPtr' ->
crop_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr'
vec = flatten (extract m)
in unsafePerformIO $ do
outPtr <- mallocForeignPtrArray outMatSize
let (inPtr, _) = U.unsafeToForeignPtr0 vec
withForeignPtr inPtr $ \inPtr' ->
withForeignPtr outPtr $ \outPtr' ->
crop_cpu inPtr' channels rows cols padLeft padTop padRight padBottom outPtr'
let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize
return (fromJust $ fromStorable matVec)
let matVec = U.unsafeFromForeignPtr0 outPtr outMatSize
return (U.matrixFromVector U.RowMajor (rows * channels) cols matVec)
foreign import ccall unsafe
crop_cpu

View File

@ -70,7 +70,6 @@ instance ( KnownNat padLeft
vs = subMatrix (padt, padl) (nrows, ncols) m
in ((), S2D . fromJust . create $ vs)
-- | A two dimentional image can be padped.
instance ( KnownNat padLeft
, KnownNat padTop
@ -87,18 +86,30 @@ instance ( KnownNat padLeft
, (inputColumns + padLeft + padRight) ~ outputColumns
) => Layer (Pad padLeft padTop padRight padBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
type Tape (Pad padLeft padTop padRight padBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = ()
runForwards Pad input =
let padl = Proxy :: Proxy padLeft
padt = Proxy :: Proxy padTop
padr = Proxy :: Proxy padRight
padb = Proxy :: Proxy padBottom
padded = pad padl padt padr padb input
in ((), padded)
runForwards Pad (S3D input) =
let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
padr = fromIntegral $ natVal (Proxy :: Proxy padRight)
padb = fromIntegral $ natVal (Proxy :: Proxy padBottom)
outr = fromIntegral $ natVal (Proxy :: Proxy outputRows)
outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
inr = fromIntegral $ natVal (Proxy :: Proxy inputRows)
inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
m = extract input
padded = pad ch padl padt padr padb inr inc outr outc m
in ((), S3D . fromJust . create $ padded)
runBackwards Pad () gradient =
let padl = Proxy :: Proxy padLeft
padt = Proxy :: Proxy padTop
padr = Proxy :: Proxy padRight
padb = Proxy :: Proxy padBottom
cropped = crop padl padt padr padb gradient
in ((), cropped)
runBackwards Pad () (S3D gradient) =
let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
padr = fromIntegral $ natVal (Proxy :: Proxy padRight)
padb = fromIntegral $ natVal (Proxy :: Proxy padBottom)
outr = fromIntegral $ natVal (Proxy :: Proxy outputRows)
outc = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
inr = fromIntegral $ natVal (Proxy :: Proxy inputRows)
inc = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
m = extract gradient
cropped = crop ch padl padt padr padb inr inc outr outc m
in ((), S3D . fromJust . create $ cropped)

View File

@ -29,23 +29,26 @@ genShape
, genD2
, genD3
]
where
genD1 = do
n <- genNat
return $ case n of
SomeNat (_ :: Proxy x) -> SomeSing (sing :: Sing ('D1 x))
genD2 = do
n <- genNat
m <- genNat
return $ case (n, m) of
(SomeNat (_ :: Proxy x), SomeNat (_ :: Proxy y)) -> SomeSing (sing :: Sing ('D2 x y))
genD1 :: Jack (SomeSing Shape)
genD1 = do
n <- genNat
return $ case n of
SomeNat (_ :: Proxy x) -> SomeSing (sing :: Sing ('D1 x))
genD3 = do
n <- genNat
m <- genNat
o <- genNat
return $ case (n, m, o) of
(SomeNat (px :: Proxy x), SomeNat (_ :: Proxy y), SomeNat (pz :: Proxy z)) ->
case natDict px %* natDict pz of
Dict -> SomeSing (sing :: Sing ('D3 x y z))
genD2 :: Jack (SomeSing Shape)
genD2 = do
n <- genNat
m <- genNat
return $ case (n, m) of
(SomeNat (_ :: Proxy x), SomeNat (_ :: Proxy y)) -> SomeSing (sing :: Sing ('D2 x y))
genD3 :: Jack (SomeSing Shape)
genD3 = do
n <- genNat
m <- genNat
o <- genNat
return $ case (n, m, o) of
(SomeNat (px :: Proxy x), SomeNat (_ :: Proxy y), SomeNat (pz :: Proxy z)) ->
case natDict px %* natDict pz of
Dict -> SomeSing (sing :: Sing ('D3 x y z))