mirror of https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00

Initial commit
This commit is contained in: commit 08afd74cde
15  .gitmodules  vendored  Normal file
@@ -0,0 +1,15 @@
[submodule "lib/disorder"]
    path = lib/disorder
    url = git@github.com:ambiata/disorder.hs
[submodule "lib/x"]
    path = lib/x
    url = git@github.com:ambiata/x.git
[submodule "lib/p"]
    path = lib/p
    url = git@github.com:ambiata/p.git
[submodule "lib/disorder.hs"]
    path = lib/disorder.hs
    url = git@github.com:ambiata/disorder.hs.git
[submodule "lib/hmatrix"]
    path = lib/hmatrix
    url = git@github.com:albertoruiz/hmatrix.git
58  README.md  Normal file
@@ -0,0 +1,58 @@
Grenade
=======

```
First shalt thou take out the Holy Pin, then shalt thou count to three, no more, no less.
Three shall be the number thou shalt count, and the number of the counting shall be three.
Four shalt thou not count, neither count thou two, excepting that thou then proceed to three.
Five is right out.
```

💣 Machine learning which might blow up in your face 💣

Grenade is a type safe, practical and pretty quick neural network library for concise and precise
specifications of complex networks in Haskell.

As an example, a network which can achieve less than 1.5% error on mnist can be specified and
initialised with random weights in under 10 lines of code with

```haskell
randomMnistNet :: MonadRandom m => m (Network Identity '[('D2 28 28), ('D3 24 24 10), ('D3 12 12 10), ('D3 12 12 10), ('D3 8 8 16), ('D3 4 4 16), ('D1 256), ('D1 256), ('D1 80), ('D1 80), ('D1 10), ('D1 10)])
randomMnistNet = do
  a :: Convolution 1 10 5 5 1 1 <- randomConvolution
  let b :: Pooling 2 2 2 2 = Pooling
  c :: Convolution 10 16 5 5 1 1 <- randomConvolution
  let d :: Pooling 2 2 2 2 = Pooling
  e :: FullyConnected 256 80 <- randomFullyConnected
  f :: FullyConnected 80 10 <- randomFullyConnected
  return $ a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit
```

The network can be thought of as a heterogeneous list of layers, and its type signature includes a type
level list of the shapes of the data passed between the layers of the network.

In the above example, the input layer can be seen to be a two dimensional (`D2`) image with 28 by 28 pixels.
The last item in the list is one dimensional (`D1`) with 10 values, representing the categories of the mnist data.
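
As a rough sketch of how those shapes line up (using the fitting rule
`out = (in - kernel) / stride + 1` from the convolution layer's documentation):

```
D2 28 28    -> Convolution 1 10 5 5 1 1  -> D3 24 24 10   ((28 - 5) / 1 + 1 = 24)
D3 24 24 10 -> Pooling 2 2 2 2           -> D3 12 12 10   ((24 - 2) / 2 + 1 = 12)
D3 12 12 10 -> Convolution 10 16 5 5 1 1 -> D3 8 8 16     ((12 - 5) / 1 + 1 = 8)
D3 8 8 16   -> Pooling 2 2 2 2           -> D3 4 4 16     ((8 - 2) / 2 + 1 = 4)
D3 4 4 16   -> FlattenLayer              -> D1 256        (4 * 4 * 16 = 256)
```

The repeated shapes in the signature correspond to the shape preserving `Relu` and `Logit` layers.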

Layers in Grenade are represented as Haskell classes, so creating one's own is easy in downstream code. If the shapes
of a network are not specified correctly and a layer cannot sensibly perform the operation between two shapes, then
it will result in a compile time error.
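
For a flavour of what that looks like, here is a minimal sketch of a custom layer in downstream code
(a hypothetical pass-through layer, not part of the library, written against the `Layer` class in
`Grenade.Core.Network`, assuming the usual extensions such as `DataKinds` and `MultiParamTypeClasses`):

```haskell
-- A hypothetical no-op layer: forwards its input unchanged, and passes
-- gradients straight back through.
data PassThrough = PassThrough
  deriving Show

instance (Monad m, KnownNat i) => Layer m PassThrough ('D1 i) ('D1 i) where
  runForwards _ x        = return x
  runBackards _ l _ dEdy = return (l, dEdy)
```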

Thanks
------
Writing a library like this has been on my mind for a while now, but a big shout out must go to Justin Le, whose
dependently typed fully connected network inspired me to get cracking, gave many ideas for the type level tools I
needed, and was a great starting point for writing this library.

Performance
-----------
Grenade is backed by hmatrix and blas, and uses a pretty clever convolution trick popularised by Caffe, which
is surprisingly effective and fast. So for many small scale problems it should be sufficient.

That said, it's currently stuck on a single core and doesn't hit up the GPU, so there's a fair bit of performance
sitting there begging.

Training 15 generations over Kaggle's mnist training data took a few hours.

Contributing
------------
Contributions are welcome.
99  grenade.cabal  Normal file
@@ -0,0 +1,99 @@
name:           grenade
version:        0.0.1
license:        AllRightsReserved
author:         Ambiata <info@ambiata.com>
maintainer:     Ambiata <info@ambiata.com>
copyright:      (c) 2015 Ambiata.
synopsis:       grenade
category:       System
cabal-version:  >= 1.8
build-type:     Simple
description:    grenade.

library
  build-depends:
      base          >= 4.8   && < 5
    , bytestring    == 0.10.*
    , either        == 4.4.*
    , exceptions    == 0.8.*
    , hmatrix
    , MonadRandom
    , mtl           >= 2.2.1 && < 2.3
    , parallel      == 3.2.*
    , text          == 1.2.*
    , transformers
    , singletons

  ghc-options:
    -Wall
  hs-source-dirs:
    src

  exposed-modules:
    Grenade
    Grenade.Core.Network
    Grenade.Core.Vector
    Grenade.Core.Runner
    Grenade.Core.Shape
    Grenade.Layers.Convolution
    Grenade.Layers.Dropout
    Grenade.Layers.FullyConnected
    Grenade.Layers.Flatten
    Grenade.Layers.Fuse
    Grenade.Layers.Logit
    Grenade.Layers.Relu
    Grenade.Layers.Tanh
    Grenade.Layers.Pooling

executable feedforward
  ghc-options: -Wall -threaded -O2
  main-is: main/feedforward.hs
  build-depends: base
               , grenade
               , attoparsec
               , either
               , optparse-applicative == 0.12.*
               , text
               , mtl >= 2.2.1 && < 2.3
               , hmatrix
               , transformers
               , singletons
               , MonadRandom

executable mnist
  ghc-options: -Wall -threaded -O2
  main-is: main/mnist.hs
  build-depends: base
               , grenade
               , attoparsec
               , either
               , optparse-applicative == 0.12.*
               , text
               , mtl >= 2.2.1 && < 2.3
               , hmatrix
               , transformers
               , singletons
               , MonadRandom

test-suite test
  type: exitcode-stdio-1.0

  main-is: test.hs

  ghc-options: -Wall -threaded -O2

  hs-source-dirs:
    test

  build-depends:
      base >= 4.8 && < 5
    , grenade
    , ambiata-disorder-core
    , hmatrix
    , mtl
    , text
    , QuickCheck == 2.7.*
    , quickcheck-instances == 0.3.*
1  lib/disorder.hs  Submodule
@@ -0,0 +1 @@
Subproject commit 43d08f1b4b3e0d43aa3233b7b5a3ea785c1d357b
1  lib/hmatrix  Submodule
@@ -0,0 +1 @@
Subproject commit 9aade51bd0bb6339cfa8aca014bd96f801d9b19e
113  mafia  Executable file
@@ -0,0 +1,113 @@
#!/bin/sh -eu

fetch_latest () {
  if [ -z ${MAFIA_TEST_MODE+x} ]; then
    TZ=$(date +"%T")
    curl --silent "https://raw.githubusercontent.com/ambiata/mafia/master/script/mafia?$TZ"
  else
    cat ../script/mafia
  fi
}

latest_version () {
  git ls-remote https://github.com/ambiata/mafia | grep refs/heads/master | cut -f 1
}

local_version () {
  awk '/^# Version: / { print $3; exit 0; }' $0
}

run_upgrade () {
  MAFIA_TEMP=$(mktemp 2>/dev/null || mktemp -t 'upgrade_mafia')

  clean_up () {
    rm -f "$MAFIA_TEMP"
  }

  trap clean_up EXIT

  MAFIA_CUR="$0"

  if [ -L "$MAFIA_CUR" ]; then
    echo 'Refusing to overwrite a symlink; run `upgrade` from the canonical path.' >&2
    exit 1
  fi

  echo "Checking for a new version of mafia ..."
  fetch_latest > $MAFIA_TEMP

  LATEST_VERSION=$(latest_version)
  echo "# Version: $LATEST_VERSION" >> $MAFIA_TEMP

  if ! cmp $MAFIA_CUR $MAFIA_TEMP >/dev/null 2>&1; then
    mv $MAFIA_TEMP $MAFIA_CUR
    chmod +x $MAFIA_CUR
    echo "New version found and upgraded. You can now commit it to your git repo."
  else
    echo "You have latest mafia."
  fi
}

exec_mafia () {
  MAFIA_VERSION=$(local_version)

  if [ "x$MAFIA_VERSION" = "x" ]; then
    # If we can't find the mafia version, then we need to upgrade the script.
    run_upgrade
  else
    MAFIA_BIN=$HOME/.ambiata/mafia/bin
    MAFIA_FILE=mafia-$MAFIA_VERSION
    MAFIA_PATH=$MAFIA_BIN/$MAFIA_FILE

    [ -f "$MAFIA_PATH" ] || {
      # Create a temporary directory which will be deleted when the script
      # terminates. Unfortunately `mktemp` doesn't behave the same on
      # Linux and OS/X so we need to try two different approaches.
      MAFIA_TEMP=$(mktemp -d 2>/dev/null || mktemp -d -t 'exec_mafia')

      # Create a temporary file in MAFIA_BIN so we can do an atomic copy/move dance.
      mkdir -p $MAFIA_BIN
      MAFIA_PATH_TEMP=$(mktemp --tmpdir=$MAFIA_BIN $MAFIA_FILE-XXXXXX 2>/dev/null || TMPDIR=$MAFIA_BIN mktemp -t $MAFIA_FILE)

      clean_up () {
        rm -rf "$MAFIA_TEMP"
        rm -f "$MAFIA_PATH_TEMP"
      }

      trap clean_up EXIT

      echo "Building $MAFIA_FILE in $MAFIA_TEMP"

      ( cd "$MAFIA_TEMP"

        git clone https://github.com/ambiata/mafia
        cd mafia

        git reset --hard $MAFIA_VERSION

        bin/bootstrap ) || exit $?

      cp "$MAFIA_TEMP/mafia/.cabal-sandbox/bin/mafia" "$MAFIA_PATH_TEMP"
      chmod +x "$MAFIA_PATH_TEMP"
      mv "$MAFIA_PATH_TEMP" "$MAFIA_PATH"
    }

    exec $MAFIA_PATH "$@"
  fi
}

#
# The actual start of the script.....
#

if [ $# -gt 0 ]; then
  MODE="$1"
else
  MODE=""
fi

case "$MODE" in
upgrade) shift; run_upgrade "$@" ;;
*) exec_mafia "$@"
esac

# Version: d64cd4f4ab42c1431752d7c84e355b7d001778f8
83  main/feedforward.hs  Normal file
@@ -0,0 +1,83 @@
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE FlexibleContexts    #-}

import           Control.Monad
import           Control.Monad.Identity
import           Control.Monad.Random

import           GHC.TypeLits

import qualified Numeric.LinearAlgebra.Static as SA

import           Options.Applicative

import           Grenade

-- The definition for our simple feed forward network.
-- The type level list represents the shapes passed through the layers. One can see that for this demonstration
-- we are using relu, tanh and logit non-linear units, which can be easily substituted for each other in and out.

-- It's important to keep the type signatures, as there's many layers which can "squeeze" into the gaps
-- between the shapes, so inference can't do it all for us.

-- With around 100000 examples, this should show two clear circles which have been learned by the network.
randomNet :: (MonadRandom m) => m (Network Identity '[('D1 2), ('D1 40), ('D1 40), ('D1 10), ('D1 10), ('D1 1), ('D1 1)])
randomNet = do
  a :: FullyConnected  2 40 <- randomFullyConnected
  b :: FullyConnected 40 10 <- randomFullyConnected
  c :: FullyConnected 10  1 <- randomFullyConnected
  return $ a :~> Tanh :~> b :~> Relu :~> c :~> O Logit

netTest :: MonadRandom m => Double -> Int -> m String
netTest rate n = do
    inps <- replicateM n $ do
      s <- getRandom
      return $ S1D' $ SA.randomVector s SA.Uniform * 2 - 1
    let outs = flip map inps $ \(S1D' v) ->
                 if v `inCircle` (fromRational 0.33, 0.33)
                    || v `inCircle` (fromRational (-0.33), 0.33)
                   then S1D' $ fromRational 1
                   else S1D' $ fromRational 0
    net0 <- randomNet

    return . runIdentity $ do
      trained <- foldM trainEach net0 (zip inps outs)
      let testIns = [ [ (x,y) | x <- [0..50] ]
                    | y <- [0..20] ]

      outMat <- traverse (traverse (\(x,y) -> (render . normx) <$> runNet trained (S1D' $ SA.vector [x / 25 - 1, y / 10 - 1]))) testIns
      return $ unlines outMat

  where
    inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
    v `inCircle` (o, r) = SA.norm_2 (v - o) <= r
    trainEach !nt !(i, o) = train rate i o nt

    render n' | n' <= 0.2 = ' '
              | n' <= 0.4 = '.'
              | n' <= 0.6 = '-'
              | n' <= 0.8 = '='
              | otherwise = '#'

    normx :: S' ('D1 1) -> Double
    normx (S1D' r) = SA.mean r

data FeedForwardOpts = FeedForwardOpts Int Double

feedForward' :: Parser FeedForwardOpts
feedForward' = FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 1000000)
                               <*> option auto (long "train_rate" <> short 'r' <> value 0.01)

main :: IO ()
main = do
  FeedForwardOpts examples rate <- execParser (info (feedForward' <**> helper) idm)
  putStrLn "Training network..."
  putStrLn =<< evalRandIO (netTest rate examples)
86  main/mnist.hs  Normal file
@@ -0,0 +1,86 @@
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE FlexibleContexts    #-}

import           Control.Applicative
import           Control.Monad
import           Control.Monad.Identity
import           Control.Monad.Random

import qualified Data.Attoparsec.Text as A
import qualified Data.Text as T
import qualified Data.Text.IO as T

import           Numeric.LinearAlgebra (maxIndex)
import qualified Numeric.LinearAlgebra.Static as SA

import           Options.Applicative

import           Grenade

-- The definition of our convolutional neural network.
-- In the type signature, we have a type level list of shapes which are passed between the layers.
-- One can see that the images we are inputting are two dimensional with 28 * 28 pixels.

-- It's important to keep the type signatures, as there's many layers which can "squeeze" into the gaps
-- between the shapes, so inference can't do it all for us.

-- With the mnist data from Kaggle normalised to doubles between 0 and 1, a learning rate of 0.01 and 15 iterations,
-- this network should get down to about a 1.3% error rate.
randomMnistNet :: (MonadRandom m) => m (Network Identity '[('D2 28 28), ('D3 24 24 10), ('D3 12 12 10), ('D3 12 12 10), ('D3 8 8 16), ('D3 4 4 16), ('D1 256), ('D1 256), ('D1 80), ('D1 80), ('D1 10), ('D1 10)])
randomMnistNet = do
  a :: Convolution 1 10 5 5 1 1 <- randomConvolution
  let b :: Pooling 2 2 2 2 = Pooling
  c :: Convolution 10 16 5 5 1 1 <- randomConvolution
  let d :: Pooling 2 2 2 2 = Pooling
  e :: FullyConnected 256 80 <- randomFullyConnected
  f :: FullyConnected 80 10 <- randomFullyConnected
  return $ a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit

convTest :: Int -> FilePath -> FilePath -> Double -> IO ()
convTest iterations trainFile validateFile rate = do
    net0 <- evalRandIO randomMnistNet
    fT   <- T.readFile trainFile
    fV   <- T.readFile validateFile
    let trainRows    = traverse (A.parseOnly p) (T.lines fT)
    let validateRows = traverse (A.parseOnly p) (T.lines fV)
    case (trainRows, validateRows) of
      (Right tr', Right vr') -> foldM_ (runIteration tr' vr') net0 [1..iterations]
      err                    -> putStrLn $ show err

  where
    trainEach !rate' !nt !(i, o) = train rate' i o nt

    p :: A.Parser (S' ('D2 28 28), S' ('D1 10))
    p = do
      lab    <- A.decimal
      pixels <- many (A.char ',' >> A.double)
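      -- Build a one-hot target from the label: e.g. label 3 becomes
      -- [0,0,0,1,0,0,0,0,0,0].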
      let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
      return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')

    runIteration trainRows validateRows net i = do
      let trained' = runIdentity $ foldM (trainEach (rate * (0.9 ^ i))) net trainRows
      let res      = runIdentity $ traverse (\(rowP,rowL) -> (rowL,) <$> runNet trained' rowP) validateRows
      let res'     = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
      putStrLn $ show trained'
      putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
      return trained'

data MnistOpts = MnistOpts FilePath FilePath Int Double

mnist' :: Parser MnistOpts
mnist' = MnistOpts <$> argument str (metavar "TRAIN")
                   <*> argument str (metavar "VALIDATE")
                   <*> option auto (long "iterations" <> short 'i' <> value 15)
                   <*> option auto (long "train_rate" <> short 'r' <> value 0.01)

main :: IO ()
main = do
  MnistOpts mnist vali iter rate <- execParser (info (mnist' <**> helper) idm)
  putStrLn "Training convolutional neural network..."
  convTest iter mnist vali rate
30  src/Grenade.hs  Normal file
@@ -0,0 +1,30 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade (
    module X
  ) where

import           Grenade.Core.Vector           as X
import           Grenade.Core.Network          as X
import           Grenade.Core.Runner           as X
import           Grenade.Core.Shape            as X
import           Grenade.Layers.Dropout        as X
import           Grenade.Layers.Pooling        as X
import           Grenade.Layers.Flatten        as X
import           Grenade.Layers.Fuse           as X
import           Grenade.Layers.FullyConnected as X
import           Grenade.Layers.Logit          as X
import           Grenade.Layers.Convolution    as X
import           Grenade.Layers.Relu           as X
import           Grenade.Layers.Tanh           as X
49  src/Grenade/Core/Network.hs  Normal file
@@ -0,0 +1,49 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Core.Network (
    Layer (..)
  , Network (..)
  ) where

import           Grenade.Core.Shape

-- | Class for a layer. All layers implement this, however, they don't
--   need to implement it for all shapes, only ones which are appropriate.
class Layer (m :: * -> *) x (i :: Shape) (o :: Shape) where
  -- | Used in training and scoring. Take the input from the previous
  --   layer, and give the output from this layer.
  runForwards :: x -> S' i -> m (S' o)
  -- | Back propagate a step. Takes a learning rate (move from here?),
  --   the current layer, the input that the layer received,
  --   and the back propagated derivatives from the layer above.
  --   Returns the updated layer and the derivatives to push back further.
  runBackards :: Double -> x -> S' i -> S' o -> m (x, S' i)

-- | Type of a network.
--   The [Shape] type specifies the shapes of data passed between the layers.
--   Could be considered to be a heterogeneous list of layers which are able to
--   transform the data shapes of the network.
data Network :: (* -> *) -> [Shape] -> * where
  O     :: (Show x, Layer m x i o, KnownShape o, KnownShape i)
        => !x
        -> Network m '[i, o]
  (:~>) :: (Show x, Layer m x i h, KnownShape h, KnownShape i)
        => !x
        -> !(Network m (h ': hs))
        -> Network m (i ': h ': hs)
infixr 5 :~>
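
-- An illustrative example (not in the original source): with the Logit layer
-- from Grenade.Layers.Logit, the two layer network
--     Logit :~> O Logit
-- has type Network m '[ 'D1 10, 'D1 10, 'D1 10 ] for any appropriate monad m;
-- each adjacent pair of shapes in the list is joined by exactly one layer.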

instance Show (Network m h) where
  show (O a)     = "O " ++ show a
  show (i :~> o) = show i ++ "\n:~>\n" ++ show o
55  src/Grenade/Core/Runner.hs  Normal file
@@ -0,0 +1,55 @@
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE TypeFamilies        #-}

module Grenade.Core.Runner (
    train
  , runNet
  ) where

import           Data.Singletons.Prelude
import           Grenade.Core.Network
import           Grenade.Core.Shape

-- | Update a network with new weights after training with an instance.
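--   A usage sketch, mirroring main/feedforward.hs in this commit:
--
--   > trained <- foldM (\nt (i, o) -> train rate i o nt) net0 (zip inps outs)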
train :: forall m i o hs. (Monad m, Head hs ~ i, Last hs ~ o, KnownShape i, KnownShape o)
      => Double          -- ^ learning rate
      -> S' i            -- ^ input vector
      -> S' o            -- ^ target vector
      -> Network m hs    -- ^ network to train
      -> m (Network m hs)
train rate x0 target = fmap fst . go x0
  where
    go  :: forall m' j js. (Monad m', Head js ~ j, Last js ~ o, KnownShape j, KnownShape o)
        => S' j            -- ^ input vector
        -> Network m' js   -- ^ network to train
        -> m' (Network m' js, S' j)
    -- handle input from the beginning, feeding upwards.
    go !x (layer :~> n)
        = do y             <- runForwards layer x
             -- run the rest of the network, and get the layer from above.
             (n', dWs')    <- go y n
             -- calculate the gradient for this layer to pass down.
             (layer', dWs) <- runBackards rate layer x dWs'
             return (layer' :~> n', dWs)

    -- handle the output layer, bouncing the derivatives back down.
    go !x (O layer)
        = do y             <- runForwards layer x
             -- the gradient (how much y affects the error)
             (layer', dWs) <- runBackards rate layer x (y - target)
             return (O layer', dWs)

-- | Just forwards propagation with no training.
runNet :: forall m hs. (Monad m)
       => Network m hs
       -> S' (Head hs)     -- ^ input vector
       -> m (S' (Last hs)) -- ^ output vector
runNet (layer :~> n) !x = do y <- runForwards layer x
                             runNet n y
runNet (O layer) !x     = runForwards layer x
83  src/Grenade/Core/Shape.hs  Normal file
@@ -0,0 +1,83 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Core.Shape (
    Shape (..)
  , S' (..)
  , KnownShape (..)
  ) where

import           Data.Singletons.TypeLits
import           Data.Proxy

import           Numeric.LinearAlgebra.Static

import           Grenade.Core.Vector

-- | The current shapes we accept.
--   At the moment this is just one, two, and three dimensional
--   vectors/matrices.
data Shape =
    D1 Nat
  | D2 Nat Nat
  | D3 Nat Nat Nat

instance KnownShape x => Num (S' x) where
  (+) (S1D' x) (S1D' y) = S1D' (x + y)
  (+) (S2D' x) (S2D' y) = S2D' (x + y)
  (+) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' + y') x y)
  (+) _ _               = error "Impossible to have different constructors for the same shaped network"

  (-) (S1D' x) (S1D' y) = S1D' (x - y)
  (-) (S2D' x) (S2D' y) = S2D' (x - y)
  (-) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' - y') x y)
  (-) _ _               = error "Impossible to have different constructors for the same shaped network"

  (*) (S1D' x) (S1D' y) = S1D' (x * y)
  (*) (S2D' x) (S2D' y) = S2D' (x * y)
  (*) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' * y') x y)
  (*) _ _               = error "Impossible to have different constructors for the same shaped network"

  abs (S1D' x) = S1D' (abs x)
  abs (S2D' x) = S2D' (abs x)
  abs (S3D' x) = S3D' (fmap abs x)

  signum (S1D' x) = S1D' (signum x)
  signum (S2D' x) = S2D' (signum x)
  signum (S3D' x) = S3D' (fmap signum x)

  fromInteger _ = error "Unimplemented: fromInteger on Shape"

-- | Given a Shape n, these are the possible data structures with that shape.
data S' (n :: Shape) where
  S1D' :: (KnownNat o)                                      => R o                           -> S' ('D1 o)
  S2D' :: (KnownNat rows, KnownNat columns)                 => L rows columns                -> S' ('D2 rows columns)
  S3D' :: (KnownNat rows, KnownNat columns, KnownNat depth) => Vector depth (L rows columns) -> S' ('D3 rows columns depth)
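
-- An illustrative example (not in the original source), using 'vector' from
-- Numeric.LinearAlgebra.Static:
--     S1D' (vector [1, 2, 3]) :: S' ('D1 3)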

instance Show (S' n) where
  show (S1D' a) = "S1D' " ++ show a
  show (S2D' a) = "S2D' " ++ show a
  show (S3D' a) = "S3D' " ++ show a

-- | Singleton for Shape
class KnownShape (n :: Shape) where
  shapeSing :: Proxy n

instance KnownShape ('D1 n) where
  shapeSing = Proxy

instance KnownShape ('D2 n m) where
  shapeSing = Proxy

instance KnownShape ('D3 l n m) where
  shapeSing = Proxy
53  src/Grenade/Core/Vector.hs  Normal file
@@ -0,0 +1,53 @@
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE PolyKinds           #-}

module Grenade.Core.Vector (
    Vector
  , vectorZip
  , vecToList
  , mkVector
  ) where

import           Data.Proxy
import           GHC.TypeLits

-- | A more specific Tagged type, ensuring that the length of the list
--   is equal to the Nat value.
newtype Vector (n :: Nat) a = Vector [a]

instance Foldable (Vector n) where
  foldr f b (Vector as) = foldr f b as

instance KnownNat n => Traversable (Vector n) where
  traverse f (Vector as) = fmap mkVector $ traverse f as

instance Functor (Vector n) where
  fmap f (Vector as) = Vector (fmap f as)

instance Show a => Show (Vector n a) where
  showsPrec d = showsPrec d . vecToList

instance Eq a => Eq (Vector n a) where
  (Vector as) == (Vector bs) = as == bs

mkVector :: forall n a. KnownNat n => [a] -> Vector n a
mkVector as
  = let du = fromIntegral . natVal $ (undefined :: Proxy n)
        la = length as
    in if du == la
         then Vector as
         else error $ "Error creating statically sized Vector of length: " ++
                      show du ++ " list is of length: " ++ show la
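
-- An illustrative example (not in the original source):
--     mkVector [1, 2, 3] :: Vector 3 Int    -- fine
--     mkVector [1, 2]    :: Vector 3 Int    -- runtime error, lengths differ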

vecToList :: Vector n a -> [a]
vecToList (Vector as) = as

vectorZip :: (a -> b -> c) -> Vector n a -> Vector n b -> Vector n c
vectorZip f (Vector as) (Vector bs) = Vector (zipWith f as bs)
276  src/Grenade/Layers/Convolution.hs  Normal file
@@ -0,0 +1,276 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE PatternGuards         #-}

module Grenade.Layers.Convolution (
    Convolution (..)
  , randomConvolution
  , im2col
  , vid2col
  , col2im
  , col2vid
  , fittingStarts
  ) where

import           Control.Monad.Random hiding (fromList)
import           Data.Maybe
import           Data.Proxy
import           Data.Singletons.TypeLits
import           GHC.TypeLits

import           Numeric.LinearAlgebra hiding (uniformSample, konst)
import qualified Numeric.LinearAlgebra as LA
import           Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)

import           Grenade.Core.Network
import           Grenade.Core.Shape
import           Grenade.Core.Vector

-- | A convolution layer for a neural network.
--   This uses the im2col convolution trick popularised by Caffe, which essentially turns the
--   many, many, many, many loop convolution into a single matrix multiplication.
--
--   The convolution layer takes all of the kernels for the convolution, which are flattened
--   and then put into columns in the matrix.
--
--   The kernel size dictates which input and output sizes will "fit". Fitting the equation:
--   `out = (in - kernel) / stride + 1` for both dimensions.
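--   For example, in the mnist network in this commit a 28x28 input with a 5x5
--   kernel and stride 1 gives out = (28 - 5) / 1 + 1 = 24 in each dimension.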
--
--   One probably shouldn't build their own layer, but rather use the randomConvolution function.
data Convolution :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
                 -> Nat -- ^ Number of filters, this is the number of channels output by the layer.
                 -> Nat -- ^ The number of rows in the kernel filter
                 -> Nat -- ^ The number of columns in the kernel filter
                 -> Nat -- ^ The row stride of the convolution filter
                 -> Nat -- ^ The column stride of the convolution filter
                 -> * where
  Convolution :: ( KnownNat channels
                 , KnownNat filters
                 , KnownNat kernelRows
                 , KnownNat kernelColumns
                 , KnownNat strideRows
                 , KnownNat strideColumns
                 , KnownNat kernelFlattened
                 , kernelFlattened ~ (kernelRows * kernelColumns * channels))
              => !(L kernelFlattened filters) -- ^ The kernel filter weights
              -> !(L kernelFlattened filters) -- ^ The last kernel update (or momentum)
              -> Convolution channels filters kernelRows kernelColumns strideRows strideColumns

instance Show (Convolution c f k k' s s') where
  show (Convolution a _) = renderConv a
    where
      renderConv mm =
        let m  = extract mm
            ky = fromIntegral $ natVal (Proxy :: Proxy k)
            rs = LA.toColumns m
            ms = map (take ky) $ toLists . reshape ky <$> rs

            render n' | n' <= 0.2 = ' '
                      | n' <= 0.4 = '.'
                      | n' <= 0.6 = '-'
                      | n' <= 0.8 = '='
                      | otherwise = '#'

            px = (fmap . fmap . fmap) render ms
        in unlines $ foldl1 (zipWith (\a' b' -> a' ++ " | " ++ b')) px

randomConvolution :: ( MonadRandom m
                     , KnownNat channels
                     , KnownNat filters
                     , KnownNat kernelRows
                     , KnownNat kernelColumns
                     , KnownNat strideRows
                     , KnownNat strideColumns
                     , KnownNat kernelFlattened
                     , kernelFlattened ~ (kernelRows * kernelColumns * channels))
                  => m (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
randomConvolution = do
  s :: Int <- getRandom
  let wN = uniformSample s (-1) 1
      mm = konst 0
  return $ Convolution wN mm

-- | A two dimensional image may have a convolution filter applied to it
instance ( Monad m
         , KnownNat kernelRows
         , KnownNat kernelCols
         , KnownNat filters
         , KnownNat strideRows
         , KnownNat strideCols
         , KnownNat inputRows
         , KnownNat inputCols
         , KnownNat outputRows
         , KnownNat outputCols
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputCols - 1) * strideCols) ~ (inputCols - kernelCols)
         ) => Layer m (Convolution 1 filters kernelRows kernelCols strideRows strideCols) ('D2 inputRows inputCols) ('D3 outputRows outputCols filters) where
  runForwards (Convolution kernel _) (S2D' input) =
    let ex = extract input
        ek = extract kernel
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
        sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
        ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
        c  = im2col kx ky sx sy ex
        mt = c LA.<> ek
        r  = col2vid 1 1 1 1 ox oy mt
        rs = fmap (fromJust . create) r
    in  return . S3D' $ mkVector rs
  runBackards rate (Convolution kernel momentum) (S2D' input) (S3D' dEdy) =
    let ex = extract input
        ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
        sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
        ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
        c  = im2col kx ky sx sy ex

        eo = vecToList $ fmap extract dEdy
        ek = extract kernel

        vs = vid2col 1 1 1 1 ox oy eo

        kN = fromJust . create $ tr c LA.<> vs
        mm = momentum * 0.9 - konst rate * kN
        wd = konst (0.0005 * rate) * kernel
        rM = kernel + mm - wd

        dW = vs LA.<> tr ek

        xW = col2im kx ky sx sy ix iy dW
    in  return (Convolution rM mm, S2D' . fromJust . create $ xW)

-- | A three dimensional image (or 2d with many channels) can have
--   an appropriately sized convolution filter run across it.
instance ( Monad m
         , KnownNat kernelRows
         , KnownNat kernelCols
         , KnownNat filters
         , KnownNat strideRows
         , KnownNat strideCols
         , KnownNat inputRows
         , KnownNat inputCols
         , KnownNat outputRows
         , KnownNat outputCols
         , KnownNat channels
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputCols - 1) * strideCols) ~ (inputCols - kernelCols)
         ) => Layer m (Convolution channels filters kernelRows kernelCols strideRows strideCols) ('D3 inputRows inputCols channels) ('D3 outputRows outputCols filters) where
  runForwards (Convolution kernel _) (S3D' input) =
    let ex = vecToList $ fmap extract input
        ek = extract kernel
        ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
        sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
        ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
        c  = vid2col kx ky sx sy ix iy ex
        mt = c LA.<> ek
        r  = col2vid 1 1 1 1 ox oy mt
        rs = fmap (fromJust . create) r
    in  return . S3D' $ mkVector rs
  runBackards rate (Convolution kernel momentum) (S3D' input) (S3D' dEdy) =
    let ex = vecToList $ fmap extract input
        ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
        sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
        ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
        c  = vid2col kx ky sx sy ix iy ex

        eo = vecToList $ fmap extract dEdy
        ek = extract kernel

        vs = vid2col 1 1 1 1 ox oy eo

        kN = fromJust . create $ tr c LA.<> vs
        mm = momentum * 0.9 - konst rate * kN
        wd = konst (0.0005 * rate) * kernel
        rM = kernel + mm - wd

        dW = vs LA.<> tr ek

        xW = col2vid kx ky sx sy ix iy dW
    in  return (Convolution rM mm, S3D' . mkVector . fmap (fromJust . create) $ xW)

im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2col nrows ncols srows scols m =
  let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
  in  im2colFit starts nrows ncols m
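
-- An illustrative example (not in the original source): im2col 2 2 1 1 applied
-- to a 3x3 matrix produces a 4x4 matrix; each of its four rows is one
-- flattened 2x2 patch of the input, so the convolution becomes a single
-- matrix multiply against the flattened kernels.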

im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
im2colFit starts nrows ncols m =
  let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts
  in  fromRows imRows

vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2col nrows ncols srows scols inputrows inputcols ms =
  let starts = fittingStarts inputrows nrows srows inputcols ncols scols
      subs   = fmap (im2colFit starts nrows ncols) ms
  in  foldl1 (|||) subs

col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vid nrows ncols srows scols drows dcols m =
  let starts = fittingStart (cols m) (nrows * ncols) (nrows * ncols)
      r      = rows m
      mats   = fmap (\s -> subMatrix (0,s) (r, nrows * ncols) m) starts
      colSts = fittingStarts drows nrows srows dcols ncols scols
  in  fmap (col2imfit colSts nrows ncols drows dcols) mats

col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2im krows kcols srows scols drows dcols m =
  let starts = fittingStarts drows krows srows dcols kcols scols
  in  col2imfit starts krows kcols drows dcols m

col2imfit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imfit starts krows kcols drows dcols m =
  let indicies = fmap (\[a,b] -> (a,b)) $ sequence [[0..(krows-1)], [0..(kcols-1)]]
      convs    = fmap (zip indicies . toList) . toRows $ m
      pairs    = zip convs starts
      accums   = concat $ fmap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
  in  accum (LA.konst 0 (drows, dcols)) (+) accums

-- | These functions are not even remotely safe, but they're only called from the statically typed
--   code, so we should be good ?!?!?
--   Returns the starting sub matrix locations which fit inside the larger matrix for the
--   convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
  let rs = fittingStart nrows kernelrows steprows
      cs = fittingStart ncols kernelcols stepcolsh
      ls = sequence [rs, cs]
  in  fmap (\[a,b] -> (a,b)) ls

-- | Returns the starting sub vector locations which fit inside the larger vector for the
--   convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
  let go left | left + kernel < width
              = left : go (left + steps)
              | left + kernel == width
              = left : []
              | otherwise
              = error "Kernel and step do not fit in matrix."
  in  go 0
32  src/Grenade/Layers/Dropout.hs  Normal file
@@ -0,0 +1,32 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Layers.Dropout (
    Dropout (..)
  ) where

import           Control.Monad.Random hiding (fromList)

import           GHC.TypeLits
import           Grenade.Core.Shape
import           Grenade.Core.Network

import           Numeric.LinearAlgebra.Static

-- A dropout layer helps to reduce overfitting.
-- The idea here is that the vector is a shape of 1s and 0s, which we multiply the input by.
-- After backpropagation, we return a new matrix/vector, with different bits dropped out.
-- The Double is the proportion to drop in each training iteration (like 1% or 5% would be
-- reasonable).
data Dropout o = Dropout Double (R o)
  deriving Show

instance (MonadRandom m, KnownNat i) => Layer m (Dropout i) ('D1 i) ('D1 i) where
  runForwards _ _     = error "todo"
  runBackards _ _ _ _ = error "todo"
45  src/Grenade/Layers/Flatten.hs  Normal file
@@ -0,0 +1,45 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Layers.Flatten (
    FlattenLayer (..)
  ) where

import           Data.Proxy
import           Data.Singletons.TypeLits
import           GHC.TypeLits

import           Numeric.LinearAlgebra.Static
import           Numeric.LinearAlgebra.Data as LA (flatten, toList, takesV, reshape, vjoin)

import           Grenade.Core.Vector
import           Grenade.Core.Shape
import           Grenade.Core.Network

data FlattenLayer = FlattenLayer
  deriving Show

instance (Monad m, KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer m FlattenLayer ('D2 x y) ('D1 a) where
  runForwards _ (S2D' y)     = return $ S1D' . fromList . toList . flatten . extract $ y
  runBackards _ _ _ (S1D' y) = return (FlattenLayer, S2D' . fromList . toList . unwrap $ y)

instance (Monad m, KnownNat a, KnownNat x, KnownNat y, KnownNat z, a ~ (x * y * z)) => Layer m FlattenLayer ('D3 x y z) ('D1 a) where
  runForwards _ (S3D' y)     = return $ S1D' . raiseShapeError . create . vjoin . vecToList . fmap (flatten . extract) $ y
  runBackards _ _ _ (S1D' o) = do
    let x'   = fromIntegral $ natVal (Proxy :: Proxy x)
        y'   = fromIntegral $ natVal (Proxy :: Proxy y)
        z'   = fromIntegral $ natVal (Proxy :: Proxy z)
        vecs = takesV (replicate z' (x' * y')) (extract o)
        ls   = fmap (raiseShapeError . create . reshape y') vecs
        ls'  = mkVector ls :: Vector z (L x y)
    return (FlattenLayer, S3D' ls')

raiseShapeError :: Maybe a -> a
raiseShapeError (Just x) = x
raiseShapeError Nothing  = error "Static shape creation from Flatten layer produced the wrong result"
56  src/Grenade/Layers/FullyConnected.hs  Normal file
@@ -0,0 +1,56 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Layers.FullyConnected (
    FullyConnected (..)
  , randomFullyConnected
  ) where

import           Control.Monad.Random hiding (fromList)

import           Data.Singletons.TypeLits

import           Numeric.LinearAlgebra.Static

import           Grenade.Core.Network
import           Grenade.Core.Shape

-- | A basic fully connected (or inner product) neural network layer.
data FullyConnected i o = FullyConnected
                        !(R o)   -- Bias neuron weights
                        !(L o i) -- Activation weights
                        !(L o i) -- Activation momentums

instance Show (FullyConnected i o) where
  show (FullyConnected _ _ _) = "FullyConnected"

instance (Monad m, KnownNat i, KnownNat o) => Layer m (FullyConnected i o) ('D1 i) ('D1 o) where
  -- Do a matrix vector multiplication and return the result.
  runForwards (FullyConnected wB wN _) (S1D' v) = return $ S1D' (wB + wN #> v)

  -- Run a backpropagation step for a fully connected layer.
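  -- The update uses classical momentum (0.9 times the previous update) and a
  -- small weight decay term (0.0005 * rate), the same constants used by the
  -- convolution layer in this commit.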
  runBackards rate (FullyConnected wB wN mm) (S1D' x) (S1D' dEdy) =
    let wB' = wB - konst rate * dEdy
        mm' = 0.9 * mm - konst rate * (dEdy `outer` x)
        wd' = konst (0.0005 * rate) * wN
        wN' = wN + mm' - wd'
        w'  = FullyConnected wB' wN' mm'
        -- calculate derivatives for next step
        dWs = tr wN #> dEdy
    in  return (w', S1D' dWs)

randomFullyConnected :: (MonadRandom m, KnownNat i, KnownNat o)
                     => m (FullyConnected i o)
randomFullyConnected = do
  s1 :: Int <- getRandom
  s2 :: Int <- getRandom
  let wB = randomVector s1 Uniform * 2 - 1
      wN = uniformSample s2 (-1) 1
      mm = konst 0
  return $ FullyConnected wB wN mm
45  src/Grenade/Layers/Fuse.hs  Normal file
@@ -0,0 +1,45 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Layers.Fuse (
    Fuse (..)
  ) where

import           Grenade.Core.Network
import           Grenade.Core.Shape

-- | Fuse two layers into one layer.
--   This can be used to simplify a network if a complicated repeated structure is used.
--   This does however have a trade off: internal incremental states in the Wengert tape are
--   not retained during reverse accumulation. So less RAM is used, but more compute is required.
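--
--   An illustrative example (not in the original source): @Relu :$$ Logit@
--   fuses two shape preserving layers into a single
--   @Fuse m ('D1 i) ('D1 i) ('D1 i)@.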
data Fuse :: (* -> *) -> Shape -> Shape -> Shape -> * where
  (:$$) :: (Show x, Show y, Layer m x i h, Layer m y h o, KnownShape h, KnownShape i, KnownShape o)
        => !x
        -> !y
        -> Fuse m i h o
infixr 5 :$$

instance Show (Fuse m i h o) where
  show (x :$$ y) = "(" ++ show x ++ " :$$ " ++ show y ++ ")"

instance (Monad m, KnownShape i, KnownShape h, KnownShape o) => Layer m (Fuse m i h o) i o where
  runForwards (x :$$ y) input = do
    yInput :: S' h <- runForwards x input
    runForwards y yInput

  runBackards rate (x :$$ y) input backGradient = do
    yInput :: S' h <- runForwards x input
    (y', yGrad)    <- runBackards rate y yInput backGradient
    (x', xGrad)    <- runBackards rate x input yGrad
    return (x' :$$ y', xGrad)
45  src/Grenade/Layers/Logit.hs  Normal file
@@ -0,0 +1,45 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Layers.Logit (
    Logit (..)
  ) where

import           Data.Singletons.TypeLits
import           Grenade.Core.Network
import           Grenade.Core.Vector
import           Grenade.Core.Shape

-- | A Logit layer.
--   A layer which can act between any shape of the same dimension, performing a sigmoid function.
--   This layer should be used as the output layer of a network for logistic regression (classification)
--   problems.
data Logit = Logit
  deriving Show

instance (Monad m, KnownNat i) => Layer m Logit ('D1 i) ('D1 i) where
  runForwards _ (S1D' y) = return $ S1D' (logistic y)
  runBackards _ _ (S1D' y) (S1D' dEdy) = return (Logit, S1D' (logistic' y * dEdy))

instance (Monad m, KnownNat i, KnownNat j) => Layer m Logit ('D2 i j) ('D2 i j) where
  runForwards _ (S2D' y) = return $ S2D' (logistic y)
  runBackards _ _ (S2D' y) (S2D' dEdy) = return (Logit, S2D' (logistic' y * dEdy))

instance (Monad m, KnownNat i, KnownNat j, KnownNat k) => Layer m Logit ('D3 i j k) ('D3 i j k) where
  runForwards _ (S3D' y) = return $ S3D' (fmap logistic y)
  runBackards _ _ (S3D' y) (S3D' dEdy) = return (Logit, S3D' (vectorZip (\y' dEdy' -> logistic' y' * dEdy') y dEdy))

logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x))

logistic' :: Floating a => a -> a
logistic' x = logix * (1 - logix)
  where
    logix = logistic x
163  src/Grenade/Layers/Pooling.hs  Normal file
@@ -0,0 +1,163 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE PolyKinds             #-}

module Grenade.Layers.Pooling (
    Pooling (..)
  , poolForward
  , poolBackward
  ) where

import           Data.Maybe
import           Data.Proxy
import           Data.Singletons.TypeLits
import           GHC.TypeLits

import           Grenade.Core.Network
import           Grenade.Core.Shape
import           Grenade.Core.Vector
import           Grenade.Layers.Convolution

import           Numeric.LinearAlgebra hiding (uniformSample)
import qualified Numeric.LinearAlgebra as LA
import           Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows)

-- | A pooling layer for a neural network.
--   Does a max pooling, looking over a kernel similarly to the convolution network, but returning
--   the maximum value only. This layer is often used to provide minor amounts of translational invariance.
--
--   The kernel size dictates which input and output sizes will "fit". Fitting the equation:
--   `out = (in - kernel) / stride + 1` for both dimensions.
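--   For example, pooling a 24x24 input with a 2x2 kernel and stride 2 gives
--   out = (24 - 2) / 2 + 1 = 12, as in the mnist network in this commit.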
|
||||||
|
--
|
||||||
|
data Pooling :: Nat
             -> Nat
             -> Nat
             -> Nat -> * where
  Pooling :: ( KnownNat kernelRows
             , KnownNat kernelColumns
             , KnownNat strideRows
             , KnownNat strideColumns
             ) => Pooling kernelRows kernelColumns strideRows strideColumns

instance Show (Pooling k k' s s') where
  show Pooling = "Pooling"

-- | A two dimensional image can be pooled.
instance ( Monad m
         , KnownNat kernelRows
         , KnownNat kernelColumns
         , KnownNat strideRows
         , KnownNat strideColumns
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer m (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
  runForwards Pooling (S2D' input) =
    let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
        sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
        ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        oy = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
        ex = extract input
        r  = poolForward kx ky sx sy ox oy ex
        rs = fromJust . create $ r
    in  return . S2D' $ rs
  runBackards _ Pooling (S2D' input) (S2D' dEdy) =
    let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
        sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
        ex = extract input
        eo = extract dEdy
        vs = poolBackward kx ky sx sy ex eo
    in  return (Pooling, S2D' . fromJust . create $ vs)

-- | A three dimensional image can be pooled on each layer.
instance ( Monad m
         , KnownNat kernelRows
         , KnownNat kernelColumns
         , KnownNat strideRows
         , KnownNat strideColumns
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer m (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
  runForwards Pooling (S3D' input) =
    let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
        sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
        ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        oy = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
        ex = fmap extract input
        r  = poolForwardList kx ky sx sy ix iy ox oy ex
        rs = fmap (fromJust . create) r
    in  return . S3D' $ rs
  runBackards _ Pooling (S3D' input) (S3D' dEdy) =
    let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
        sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
        ex = fmap extract input
        eo = fmap extract dEdy
        ez = vectorZip (,) ex eo
        vs = poolBackwardList kx ky sx sy ix iy ez
    in  return (Pooling, S3D' . fmap (fromJust . create) $ vs)

poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForward nrows ncols srows scols outputRows outputCols m =
  let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
  in  poolForwardFit starts nrows ncols outputRows outputCols m

poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double)
poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
  let starts = fittingStarts inRows nrows srows inCols ncols scols
  in  poolForwardFit starts nrows ncols outputRows outputCols <$> ms

-- Take the maximum element of each kernel window, laying the results out as
-- the output matrix.
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForwardFit starts nrows ncols _ outputCols m =
  let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
  in  LA.matrix outputCols els

poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackward krows kcols srows scols inputMatrix gradientMatrix =
  let inRows = rows inputMatrix
      inCols = cols inputMatrix
      starts = fittingStarts inRows krows srows inCols kcols scols
  in  poolBackwardFit starts krows kcols inputMatrix gradientMatrix

poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double)
poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
  let starts = fittingStarts inRows krows srows inCols kcols scols
  in  uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices

-- Route each incoming gradient back to the argmax position of its window,
-- summing where windows overlap.
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
  let inRows = rows inputMatrix
      inCols = cols inputMatrix
      inds   = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
      grads  = toList $ flatten gradientMatrix
      grads' = zip3 starts grads inds
      accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
  in  accum (LA.konst 0 (inRows, inCols)) (+) accums
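Since `poolForward` and `poolBackward` are exported, they can be exercised directly outside a network. A small usage sketch (illustrative only; the values mirror the `prop_pool` property test later in this commit):

```haskell
import Numeric.LinearAlgebra ((><), Matrix)
import Grenade.Layers.Pooling (poolForward)

-- A 2x2 max pool with stride 1 over a 3x4 input gives a 2x3 output:
-- each output cell is the maximum of one 2x2 window.
pooled :: Matrix Double
pooled = poolForward 2 2 1 1 2 3 ((3><4) [ 1,  2,  3,  4
                                         , 5,  6,  7,  8
                                         , 9, 10, 11, 12 ])
-- pooled == (2><3) [ 6, 7, 8, 10, 11, 12 ]
```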
49 src/Grenade/Layers/Relu.hs Normal file
@ -0,0 +1,49 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Layers.Relu (
    Relu (..)
  ) where

import           GHC.TypeLits

import           Grenade.Core.Vector
import           Grenade.Core.Network
import           Grenade.Core.Shape

import qualified Numeric.LinearAlgebra.Static as LAS

-- | A rectifying linear unit.
--   A layer which can act between any shape of the same dimension, acting as a
--   diode on every neuron individually.
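--   In other words (added note, not in the original comment):
--   relu x = max 0 x, with gradient 1 where the input is positive and 0
--   elsewhere, which is what the dvmap/dmmap definitions below compute
--   pointwise.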
data Relu = Relu
  deriving Show

instance (Monad m, KnownNat i) => Layer m Relu ('D1 i) ('D1 i) where
  runForwards _ (S1D' y) = return $ S1D' (relu y)
    where
      relu = LAS.dvmap (\a -> if a <= 0 then 0 else a)
  runBackards _ _ (S1D' y) (S1D' dEdy) = return (Relu, S1D' (relu' y * dEdy))
    where
      relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1)

instance (Monad m, KnownNat i, KnownNat j) => Layer m Relu ('D2 i j) ('D2 i j) where
  runForwards _ (S2D' y) = return $ S2D' (relu y)
    where
      relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
  runBackards _ _ (S2D' y) (S2D' dEdy) = return (Relu, S2D' (relu' y * dEdy))
    where
      relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)

instance (Monad m, KnownNat i, KnownNat j, KnownNat k) => Layer m Relu ('D3 i j k) ('D3 i j k) where
  runForwards _ (S3D' y) = return $ S3D' (fmap relu y)
    where
      relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
  runBackards _ _ (S3D' y) (S3D' dEdy) = return (Relu, S3D' (vectorZip (\y' dEdy' -> relu' y' * dEdy') y dEdy))
    where
      relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
37 src/Grenade/Layers/Tanh.hs Normal file
@ -0,0 +1,37 @@
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}

module Grenade.Layers.Tanh (
    Tanh (..)
  ) where

import           GHC.TypeLits

import           Grenade.Core.Vector
import           Grenade.Core.Network
import           Grenade.Core.Shape
-- | A Tanh layer.
--   A layer which can act between any shape of the same dimension, performing a tanh function.
data Tanh = Tanh
  deriving Show

instance (Monad m, KnownNat i) => Layer m Tanh ('D1 i) ('D1 i) where
  runForwards _ (S1D' y) = return $ S1D' (tanh y)
  runBackards _ _ (S1D' y) (S1D' dEdy) = return (Tanh, S1D' (tanh' y * dEdy))

instance (Monad m, KnownNat i, KnownNat j) => Layer m Tanh ('D2 i j) ('D2 i j) where
  runForwards _ (S2D' y) = return $ S2D' (tanh y)
  runBackards _ _ (S2D' y) (S2D' dEdy) = return (Tanh, S2D' (tanh' y * dEdy))

instance (Monad m, KnownNat i, KnownNat j, KnownNat k) => Layer m Tanh ('D3 i j k) ('D3 i j k) where
  runForwards _ (S3D' y) = return $ S3D' (fmap tanh y)
  runBackards _ _ (S3D' y) (S3D' dEdy) = return (Tanh, S3D' (vectorZip (\y' dEdy' -> tanh' y' * dEdy') y dEdy))

-- | The derivative of tanh, computed from the forward activation.
tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
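The forward pass needs no new definition since `tanh` is in the Prelude, and the derivative identity d/dx tanh x = 1 - tanh^2 x means the backward pass can be computed from the forward output alone, which is exactly what `tanh'` does. A quick check one could run in GHCi (illustrative only, not part of this commit):

```haskell
-- >>> let numDeriv f x = (f (x + 1e-4) - f (x - 1e-4)) / 2e-4
-- >>> abs (tanh' 0.5 - numDeriv tanh 0.5) < 1e-6
-- True
```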
251 test/Test/Grenade/Layers/Convolution.hs Normal file
@ -0,0 +1,251 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds       #-}
{-# LANGUAGE GADTs           #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Convolution where

import           Control.Monad.Identity

import           Grenade.Core.Shape
import           Grenade.Core.Vector as Grenade
import           Grenade.Core.Network
import           Grenade.Layers.Convolution

import           Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
import qualified Numeric.LinearAlgebra.Static as HStatic

import           Test.QuickCheck hiding ((><))

prop_im2col_no_stride = once $
  let input = (3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
      expected = (6><4)
              [ 1.0, 2.0,  5.0,  6.0
              , 2.0, 3.0,  6.0,  7.0
              , 3.0, 4.0,  7.0,  8.0
              , 5.0, 6.0,  9.0, 10.0
              , 6.0, 7.0, 10.0, 11.0
              , 7.0, 8.0, 11.0, 12.0 ]
      out = im2col 2 2 1 1 input
  in  expected === out
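-- Added note (not in the original test): with a 2x2 kernel and stride 1, a
-- 3x4 image has 2x3 = 6 kernel positions, so im2col yields one row per
-- position, each row a flattened 2x2 patch; e.g. the first row {1, 2, 5, 6}
-- is the top-left patch.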

prop_im2col_stride = once $
  let input = (3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
      expected = (4><4)
              [ 1.0, 2.0,  5.0,  6.0
              , 3.0, 4.0,  7.0,  8.0
              , 5.0, 6.0,  9.0, 10.0
              , 7.0, 8.0, 11.0, 12.0 ]
      out = im2col 2 2 1 2 input
  in  expected === out

prop_im2col_other = once $
  let input = (3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
      expected = (2><6)
              [ 1.0, 2.0, 5.0, 6.0,  9.0, 10.0
              , 3.0, 4.0, 7.0, 8.0, 11.0, 12.0 ]
      out = im2col 3 2 1 2 input
  in  expected === out

-- If there's no overlap (stride is the same size as the kernel)
-- then col2im . im2col should be symmetric.
prop_im2col_sym_on_same_stride = once $
  let input = (3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
      out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input
  in  input === out

-- If there is an overlap, then the gradient passed back should be
-- the sum of the gradients across the filters.
prop_im2col_col2im_additive = once $
  let input = (3><4)
              [ 1.0, 1.0, 1.0, 1.0
              , 1.0, 1.0, 1.0, 1.0
              , 1.0, 1.0, 1.0, 1.0 ]
      expected = (3><4)
              [ 1.0, 2.0, 2.0, 1.0
              , 2.0, 4.0, 4.0, 2.0
              , 1.0, 2.0, 2.0, 1.0 ]
      out = col2im 2 2 1 1 3 4 . im2col 2 2 1 1 $ input
  in  expected === out

prop_simple_conv_forwards = once $
  -- Create a convolution kernel with 4 filters.
  -- [ 1, 0    [ 0, 1    [ 0, 1    [ 0, 0
  -- , 0,-1 ]  ,-1, 0 ]  , 1, 0 ]  ,-1,-1 ]
  let myKernel = (HStatic.matrix
                 [ 1.0,  0.0, 0.0,  0.0
                 , 0.0,  1.0, 1.0,  0.0
                 , 0.0, -1.0, 1.0, -1.0
                 ,-1.0,  0.0, 0.0, -1.0 ] :: HStatic.L 4 4)
      zeroKernel = (HStatic.matrix
                 [ 0.0, 0.0, 0.0, 0.0
                 , 0.0, 0.0, 0.0, 0.0
                 , 0.0, 0.0, 0.0, 0.0
                 , 0.0, 0.0, 0.0, 0.0 ] :: HStatic.L 4 4)
      --expectedKernel = (HStatic.matrix
      --           [ 0.0,  0.0, 0.0, -2.0
      --           ,-2.0,  1.0, 1.0, -5.0
      --           ,-3.0, -1.0, 1.0, -5.0
      --           ,-5.0,  0.0, 0.0, -7.0 ] :: HStatic.L 4 4)

      convLayer = Convolution myKernel zeroKernel :: Convolution 1 4 2 2 1 1

      input = S2D' (HStatic.matrix
                 [ 1.0, 2.0, 5.0
                 , 3.0, 4.0, 6.0 ] :: HStatic.L 2 3)

      expect = ([(HStatic.matrix [ -3.0,  -4.0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ -1.0,   1.0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [  5.0,   9.0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ -7.0, -10.0 ] :: HStatic.L 1 2)]) :: [HStatic.L 1 2]
      out = runIdentity $ runForwards convLayer input :: S' ('D3 1 2 4)

      grad = S3D' ( mkVector
                [(HStatic.matrix [ 1, 0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ 0, 0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ 0, 0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ 0, 1 ] :: HStatic.L 1 2)] ) :: S' ('D3 1 2 4)

      expectBack = (HStatic.matrix
                 [ 1.0,  0.0,  0.0
                 , 0.0, -2.0, -1.0 ] :: HStatic.L 2 3)
      (nc, inX) = runIdentity $ runBackards 1 convLayer input grad :: ( Convolution 1 4 2 2 1 1, S' ('D2 2 3))

  in  case (out, inX, nc) of
        (S3D' out', S2D' inX', Convolution _ _) ->
               ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
          .&&. ((HStatic.extract expectBack) === (HStatic.extract inX'))
          -- Temporarily disabled, as the l2 adjustment puts it off by 5%:
          -- .&&. HStatic.extract expectedKernel === HStatic.extract kernel'

prop_vid2col_no_stride = once $
  let input = [(3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
              , (3><4)
              [ 21.0, 22.0, 23.0, 24.0
              , 25.0, 26.0, 27.0, 28.0
              , 29.0, 30.0, 31.0, 32.0 ] ]
      expected = (6><8)
              [ 1.0, 2.0,  5.0,  6.0, 21.0, 22.0, 25.0, 26.0
              , 2.0, 3.0,  6.0,  7.0, 22.0, 23.0, 26.0, 27.0
              , 3.0, 4.0,  7.0,  8.0, 23.0, 24.0, 27.0, 28.0
              , 5.0, 6.0,  9.0, 10.0, 25.0, 26.0, 29.0, 30.0
              , 6.0, 7.0, 10.0, 11.0, 26.0, 27.0, 30.0, 31.0
              , 7.0, 8.0, 11.0, 12.0, 27.0, 28.0, 31.0, 32.0 ]
      out = vid2col 2 2 1 1 3 4 input
  in  expected === out

prop_vid2col_stride = once $
  let input = [(3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
              , (3><4)
              [ 21.0, 22.0, 23.0, 24.0
              , 25.0, 26.0, 27.0, 28.0
              , 29.0, 30.0, 31.0, 32.0 ] ]
      expected = (4><8)
              [ 1.0, 2.0,  5.0,  6.0, 21.0, 22.0, 25.0, 26.0
              , 3.0, 4.0,  7.0,  8.0, 23.0, 24.0, 27.0, 28.0
              , 5.0, 6.0,  9.0, 10.0, 25.0, 26.0, 29.0, 30.0
              , 7.0, 8.0, 11.0, 12.0, 27.0, 28.0, 31.0, 32.0 ]
      out = vid2col 2 2 1 2 3 4 input
  in  expected === out

prop_vid2col_invert = once $
  let input = [(3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
              , (3><4)
              [ 21.0, 22.0, 23.0, 24.0
              , 25.0, 26.0, 27.0, 28.0
              , 29.0, 30.0, 31.0, 32.0 ] ]
      out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
  in  input === out

-- This test shows that 2D convolutions act the same as
-- 3D convolutions with a single channel.
prop_single_conv_forwards = once $
  -- Create a convolution kernel with 4 filters.
  -- [ 1, 0    [ 0, 1    [ 0, 1    [ 0, 0
  -- , 0,-1 ]  ,-1, 0 ]  , 1, 0 ]  ,-1,-1 ]
  let myKernel = (HStatic.matrix
                 [ 1.0,  0.0, 0.0,  0.0
                 , 0.0,  1.0, 1.0,  0.0
                 , 0.0, -1.0, 1.0, -1.0
                 ,-1.0,  0.0, 0.0, -1.0 ] :: HStatic.L 4 4)
      zeroKernel = (HStatic.matrix
                 [ 0.0, 0.0, 0.0, 0.0
                 , 0.0, 0.0, 0.0, 0.0
                 , 0.0, 0.0, 0.0, 0.0
                 , 0.0, 0.0, 0.0, 0.0 ] :: HStatic.L 4 4)
      --expectedKernel = (HStatic.matrix
      --           [ 0.0,  0.0, 0.0, -2.0
      --           ,-2.0,  1.0, 1.0, -5.0
      --           ,-3.0, -1.0, 1.0, -5.0
      --           ,-5.0,  0.0, 0.0, -7.0 ] :: HStatic.L 4 4)

      convLayer = Convolution myKernel zeroKernel :: Convolution 1 4 2 2 1 1

      input = S3D' ( mkVector [HStatic.matrix
                 [ 1.0, 2.0, 5.0
                 , 3.0, 4.0, 6.0 ] :: HStatic.L 2 3] ) :: S' ('D3 2 3 1)

      expect = ([(HStatic.matrix [ -3.0,  -4.0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ -1.0,   1.0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [  5.0,   9.0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ -7.0, -10.0 ] :: HStatic.L 1 2)]) :: [HStatic.L 1 2]
      out = runIdentity $ runForwards convLayer input :: S' ('D3 1 2 4)

      grad = S3D' ( mkVector
                [(HStatic.matrix [ 1, 0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ 0, 0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ 0, 0 ] :: HStatic.L 1 2)
                ,(HStatic.matrix [ 0, 1 ] :: HStatic.L 1 2)] ) :: S' ('D3 1 2 4)

      expectBack = (HStatic.matrix
                 [ 1.0,  0.0,  0.0
                 , 0.0, -2.0, -1.0 ] :: HStatic.L 2 3)
      (nc, inX) = runIdentity $ runBackards 1 convLayer input grad :: ( Convolution 1 4 2 2 1 1, S' ('D3 2 3 1))

  in  case (out, inX, nc) of
        (S3D' out', S3D' inX', Convolution _ _) ->
               ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
          .&&. ([HStatic.extract expectBack] === (HStatic.extract <$> vecToList inX'))
          -- .&&. HStatic.extract expectedKernel === HStatic.extract kernel'

return []
tests :: IO Bool
tests = $quickCheckAll
56 test/Test/Grenade/Layers/Pooling.hs Normal file
@ -0,0 +1,56 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds       #-}
{-# LANGUAGE GADTs           #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Pooling where

import           Grenade.Layers.Pooling

import           Numeric.LinearAlgebra hiding (uniformSample, konst, (===))

import           Test.QuickCheck hiding ((><))

prop_pool = once $
  let input = (3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
      expected = (2><3)
              [  6.0,  7.0,  8.0
              , 10.0, 11.0, 12.0 ]
      out = poolForward 2 2 1 1 2 3 input
  in  expected === out
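-- Added note (not in the original test): with a 2x2 kernel and stride 1
-- there are 2x3 window positions; e.g. the top-left window {1, 2, 5, 6} has
-- maximum 6, which is the first entry of the expected output.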

prop_pool_backwards = once $
  let input = (3><4)
              [ 1.0,  2.0,  3.0,  4.0
              , 5.0,  6.0,  7.0,  8.0
              , 9.0, 10.0, 11.0, 12.0 ]
      grads = (2><3)
              [  -6.0,  -7.0,  -8.0
              , -10.0, -11.0, -12.0 ]
      expected = (3><4)
              [ 0.0,   0.0,   0.0,   0.0
              , 0.0,  -6.0,  -7.0,  -8.0
              , 0.0, -10.0, -11.0, -12.0 ]
      out = poolBackward 2 2 1 1 input grads
  in  expected === out

prop_pool_backwards_additive = once $
  let input = (3><4)
              [ 4.0, 2.0, 3.0, 4.0
              , 0.0, 0.0, 7.0, 8.0
              , 9.0, 0.0, 0.0, 0.0 ]
      grads = (2><3)
              [  -6.0,  -7.0,  -8.0
              , -10.0, -11.0, -12.0 ]
      expected = (3><4)
              [ -6.0, 0.0,   0.0,   0.0
              ,  0.0, 0.0, -18.0, -20.0
              ,-10.0, 0.0,   0.0,   0.0 ]
      out = poolBackward 2 2 1 1 input grads
  in  expected === out

return []
tests :: IO Bool
tests = $quickCheckAll
11 test/test.hs Normal file
@ -0,0 +1,11 @@
import           Disorder.Core.Main

import qualified Test.Grenade.Layers.Pooling     as Test.Grenade.Layers.Pooling
import qualified Test.Grenade.Layers.Convolution as Test.Grenade.Layers.Convolution

main :: IO ()
main =
  disorderMain [
      Test.Grenade.Layers.Pooling.tests
    , Test.Grenade.Layers.Convolution.tests
    ]