grenade — https://github.com/HuwCampbell/grenade.git
Remove primes on shape instantiations

Add singletons for Shape and remove hacks on recurrent nets.
Add recurrent nets.
This commit is contained in: commit ac0e4b22c8 (parent 88021fbde7)
@@ -52,7 +52,7 @@ Usage
 To perform back propagation, one can call the eponymous function
 ```haskell
 backPropagate :: forall input target shapes layers. (Head shapes ~ input, Last shapes ~ target)
-              => Network layers shapes -> S' input -> S' target -> Gradients layers
+              => Network layers shapes -> S input -> S target -> Gradients layers
 ```
 which takes a network, appropriate input and target data, and returns the
 back propagated gradients for the network. The shapes of the gradients are
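The `train` function touched later in this commit is a thin wrapper over `backPropagate`; as a minimal sketch (assuming the `applyUpdate` and `LearningParameters` defined in `Grenade.Core.Runner`), one training step looks like:

```haskell
-- A minimal sketch of a single SGD step, mirroring the train function in
-- this commit: back propagate, then apply the gradients to every layer.
step :: LearningParameters
     -> Network layers shapes
     -> S (Head shapes)        -- input sample
     -> S (Last shapes)        -- target output
     -> Network layers shapes
step params net x y =
  let grads = backPropagate net x y
  in  applyUpdate params net grads
```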
cbits/gradient_decent.c (new file, 13 lines)
@@ -0,0 +1,13 @@
+#include "gradient_decent.h"
+
+void decend_cpu(int len, double rate, double momentum, double regulariser,
+                const double* weights,
+                const double* gradient,
+                const double* last,
+                double* outputWeights, double* outputMomentum) {
+
+  for (int i = 0; i <= len; i++) {
+    outputMomentum[i] = momentum * last[i] - rate * gradient[i];
+    outputWeights[i] = weights[i] + outputMomentum[i] - (rate * regulariser) * weights[i];
+  }
+}
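For readers more comfortable in Haskell than C, here is a hedged, pure rendering of the update this kernel performs; the name `descend` and the list representation are illustrative only (the library calls the C routine over raw storable vectors):

```haskell
-- Reference semantics of decend_cpu, element-wise:
--   m' = momentum * m - rate * g
--   w' = w + m' - rate * regulariser * w
descend :: Double -> Double -> Double
        -> [Double]              -- weights
        -> [Double]              -- gradient
        -> [Double]              -- last update (momentum)
        -> ([Double], [Double])  -- (new weights, new momentum)
descend rate momentum regulariser ws gs ms =
  let ms' = zipWith (\m g -> momentum * m - rate * g) ms gs
      ws' = zipWith (\w m' -> w + m' - rate * regulariser * w) ws ms'
  in  (ws', ms')
```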
cbits/gradient_decent.h (new file, 9 lines)
@@ -0,0 +1,9 @@
+#include <stdio.h>
+#include <stdint.h>
+
+void decend_cpu(int len, double rate, double momentum, double regulariser,
+                const double* weights,
+                const double* gradient,
+                const double* last,
+                double* outputWeights, double* outputMomentum);
+
cbits/im2col.c
@@ -1,11 +1,10 @@
 #include "im2col.h"
 
-void im2col_cpu(const double* data_im, int dataOffset, const int channels,
+void im2col_cpu(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_col) {
 
-  data_im += dataOffset;
   const int channel_size = height * width;
 
   for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
@@ -23,13 +22,12 @@ void im2col_cpu(const double* data_im, int dataOffset, const int channels,
   }
 }
 
-void col2im_cpu(const double* data_col, int dataOffset, const int channels,
+void col2im_cpu(const double* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_im) {
 
   memset(data_im, 0, height * width * channels * sizeof(double));
-  data_col += dataOffset;
 
   const int channel_size = height * width;
 
@@ -50,13 +48,11 @@ void col2im_cpu(const double* data_col, int dataOffset, const int channels,
 
 inline double max ( double a, double b ) { return a > b ? a : b; }
 
-void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
+void pool_forwards_cpu(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_pooled) {
 
-  data_im += dataOffset;
-
   const int channel_size = height * width;
 
   for (int channel = 0; channel < channels; channel++) {
@@ -89,14 +85,11 @@ void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels
   }
 }
 
-void pool_backwards_cpu(const double* data_im, int data_im_offset,
-    const double* data_pooled, int data_pooled_offset,
+void pool_backwards_cpu(const double* data_im, const double* data_pooled,
     const int channels, const int height, const int width, const int kernel_h,
     const int kernel_w, const int stride_h, const int stride_w,
     double* data_backgrad ) {
 
-  data_im += data_im_offset;
-  data_pooled += data_pooled_offset;
   memset(data_backgrad, 0, height * width * channels * sizeof(double));
 
   const int channel_size = height * width;
cbits/im2col.h
@@ -2,23 +2,22 @@
 #include <stdint.h>
 #include <string.h>
 
-void im2col_cpu(const double* data_im, int dataOffset, const int channels,
+void im2col_cpu(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_col);
 
-void col2im_cpu(const double* data_col, int dataOffset, const int channels,
+void col2im_cpu(const double* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_im);
 
-void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
+void pool_forwards_cpu(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
    double* data_pooled);
 
-void pool_backwards_cpu(const double* data_im, int data_im_offset,
-    const double* data_pooled, int data_pooled_offset,
+void pool_backwards_cpu(const double* data_im, const double* data_pooled,
     const int channels, const int height, const int width, const int kernel_h,
     const int kernel_w, const int stride_h, const int stride_w,
     double* data_backgrad );
grenade.cabal
@@ -19,6 +19,8 @@ library
     base >= 4.8 && < 5
   , bytestring == 0.10.*
   , async
+  , containers
+  , deepseq
   , either == 4.4.*
   , exceptions == 0.8.*
   , hmatrix
@@ -26,9 +28,11 @@ library
   , mtl >= 2.2.1 && < 2.3
   , parallel == 3.2.*
   , primitive
+  , reflection
   , text == 1.2.*
   , transformers
   , singletons
+  , vector
 
   ghc-options:
     -Wall
@@ -55,9 +59,23 @@ library
 
     Grenade.Layers.Internal.Convolution
     Grenade.Layers.Internal.Pooling
+    Grenade.Layers.Internal.Update
+
+    Grenade.Recurrent
+
+    Grenade.Recurrent.Core.Network
+    Grenade.Recurrent.Core.Runner
+
+    Grenade.Recurrent.Layers.BasicRecurrent
+    Grenade.Recurrent.Layers.LSTM
+    Grenade.Recurrent.Layers.Trivial
+
+    Grenade.Utils.OneHot
+
   includes: cbits/im2col.h
+            cbits/gradient_decent.h
   c-sources: cbits/im2col.c
+             cbits/gradient_decent.c
 
   cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
 
@@ -90,6 +108,40 @@ executable mnist
   , transformers
   , singletons
   , MonadRandom
+  , vector
+
+executable recurrent
+  ghc-options: -Wall -threaded -O2
+  main-is: main/recurrent.hs
+  build-depends: base
+    , grenade
+    , attoparsec
+    , either
+    , optparse-applicative == 0.12.*
+    , text == 1.2.*
+    , mtl >= 2.2.1 && < 2.3
+    , hmatrix >= 0.18 && < 0.19
+    , transformers
+    , singletons
+    , MonadRandom
+
+
+executable shakespeare
+  ghc-options: -Wall -threaded -O2
+  main-is: main/shakespeare.hs
+  build-depends: base
+    , grenade
+    , attoparsec
+    , either
+    , optparse-applicative == 0.12.*
+    , text == 1.2.*
+    , mtl >= 2.2.1 && < 2.3
+    , hmatrix >= 0.18 && < 0.19
+    , transformers
+    , singletons
+    , vector
+    , MonadRandom
+    , containers
+
+
 test-suite test
@@ -117,6 +169,8 @@ test-suite test
   , quickcheck-instances == 0.3.*
   , MonadRandom
   , random
+  , ad
+  , reflection
 
 
 benchmark bench
@@ -135,3 +189,20 @@ benchmark bench
   , criterion == 1.1.*
   , grenade
   , hmatrix
+
+benchmark bench-lstm
+  type: exitcode-stdio-1.0
+
+  main-is: bench-lstm.hs
+
+  ghc-options: -Wall -threaded -O2
+
+  hs-source-dirs:
+    bench
+
+  build-depends:
+    base >= 3 && < 5
+    , bytestring == 0.10.*
+    , criterion == 1.1.*
+    , grenade
+    , hmatrix
mafia (6 changes)
@@ -1,5 +1,7 @@
 #!/bin/sh -eu
 
+: ${MAFIA_HOME:=$HOME/.mafia}
+
 fetch_latest () {
   if [ -z ${MAFIA_TEST_MODE+x} ]; then
     TZ=$(date +"%T")
@@ -55,7 +57,7 @@ exec_mafia () {
     # If we can't find the mafia version, then we need to upgrade the script.
     run_upgrade
   else
-    MAFIA_BIN=$HOME/.ambiata/mafia/bin
+    MAFIA_BIN=$MAFIA_HOME/bin
     MAFIA_FILE=mafia-$MAFIA_VERSION
     MAFIA_PATH=$MAFIA_BIN/$MAFIA_FILE
 
@@ -118,4 +120,4 @@ case "$MODE" in
   upgrade) shift; run_upgrade "$@" ;;
   *) exec_mafia "$@"
 esac
-# Version: a1b39ee8ac1969ed2e891b9062d079be75863e99
+# Version: 3044e63eb472fb9e16926d4ab2ca9dd9e255829c
main/feedforward.hs
@@ -4,10 +4,10 @@
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TupleSections #-}
 {-# LANGUAGE TypeFamilies  #-}
-{-# LANGUAGE FlexibleContexts #-}
 
 import Control.Monad
 import Control.Monad.Random
+import Data.List ( foldl' )
 
 import GHC.TypeLits
 
 import qualified Numeric.LinearAlgebra.Static as SA
@@ -34,18 +34,18 @@ netTest :: MonadRandom m => LearningParameters -> Int -> m String
 netTest rate n = do
   inps <- replicateM n $ do
     s <- getRandom
-    return $ S1D' $ SA.randomVector s SA.Uniform * 2 - 1
+    return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
-  let outs = flip map inps $ \(S1D' v) ->
+  let outs = flip map inps $ \(S1D v) ->
     if v `inCircle` (fromRational 0.33, 0.33) || v `inCircle` (fromRational (-0.33), 0.33)
-      then S1D' $ fromRational 1
-      else S1D' $ fromRational 0
+      then S1D $ fromRational 1
+      else S1D $ fromRational 0
   net0 <- randomNet
 
-  let trained = foldl trainEach net0 (zip inps outs)
+  let trained = foldl' trainEach net0 (zip inps outs)
   let testIns = [ [ (x,y) | x <- [0..50] ]
                 | y <- [0..20] ]
 
-  let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D' $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
+  let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
   return $ unlines outMat
 
     where
@@ -59,8 +59,8 @@ netTest rate n = do
       | n' <= 0.8 = '='
       | otherwise = '#'
 
-    normx :: S' ('D1 1) -> Double
-    normx (S1D' r) = SA.mean r
+    normx :: S ('D1 1) -> Double
+    normx (S1D r) = SA.mean r
 
 data FeedForwardOpts = FeedForwardOpts Int LearningParameters
 
main/mnist.hs
@@ -5,23 +5,24 @@
 {-# LANGUAGE TupleSections    #-}
 {-# LANGUAGE TypeFamilies     #-}
 {-# LANGUAGE FlexibleContexts #-}
 
 import Control.Applicative
 import Control.Monad
 import Control.Monad.Random
-import Control.Monad.Trans.Class
 import Control.Monad.Trans.Except
 
 import qualified Data.Attoparsec.Text as A
+import Data.List ( foldl' )
 import qualified Data.Text as T
 import qualified Data.Text.IO as T
+import qualified Data.Vector.Storable as V
 
-import Numeric.LinearAlgebra (maxIndex)
+import Numeric.LinearAlgebra ( maxIndex )
 import qualified Numeric.LinearAlgebra.Static as SA
 
 import Options.Applicative
 
 import Grenade
+import Grenade.Utils.OneHot
 
 -- The definition of our convolutional neural network.
 -- In the type signature, we have a type level list of shapes which are passed between the layers.
@@ -49,9 +50,9 @@ convTest iterations trainFile validateFile rate = do
     trainEach rate' !network (i, o) = train rate' network i o
 
     runIteration trainRows validateRows net i = do
-      let trained' = foldl (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
+      let trained' = foldl' (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
       let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
-      let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
+      let res' = fmap (\(S1D label, S1D prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
       print trained'
      putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
       return trained'
@@ -61,7 +62,7 @@ data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters
 mnist' :: Parser MnistOpts
 mnist' = MnistOpts <$> argument str (metavar "TRAIN")
                    <*> argument str (metavar "VALIDATE")
-                   <*> option auto (long "iterations" <> short 'i' <> value 10)
+                   <*> option auto (long "iterations" <> short 'i' <> value 15)
                    <*> (LearningParameters
                        <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
                        <*> option auto (long "momentum" <> value 0.9)
@@ -78,14 +79,14 @@ main = do
     Right () -> pure ()
     Left err -> putStrLn err
 
-readMNIST :: FilePath -> ExceptT String IO [(S' ('D2 28 28), S' ('D1 10))]
+readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
 readMNIST mnist = ExceptT $ do
   mnistdata <- T.readFile mnist
   return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)
 
-parseMNIST :: A.Parser (S' ('D2 28 28), S' ('D1 10))
+parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
 parseMNIST = do
-  lab <- A.decimal
+  Just lab <- oneHot <$> A.decimal
   pixels <- many (A.char ',' >> A.double)
-  let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
-  return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')
+  image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
+  return (image, lab)
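The new parser is also stricter than the old hand-rolled one: `fromStorable` returns `Nothing` whenever the pixel vector does not match the requested shape, so malformed rows now fail at parse time. A hedged illustration (values and shapes chosen for the example, not taken from the commit):

```haskell
-- fromStorable checks the vector length against the type-level shape.
good :: Maybe (S ('D1 3))
good = fromStorable (V.fromList [0.1, 0.2, 0.3])   -- Just (S1D ...)

bad :: Maybe (S ('D1 3))
bad = fromStorable (V.fromList [0.1, 0.2])         -- Nothing: 2 elements, not 3
```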
main/recurrent.hs (new file, 87 lines)
@@ -0,0 +1,87 @@
+{-# LANGUAGE BangPatterns        #-}
+{-# LANGUAGE DataKinds           #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators       #-}
+{-# LANGUAGE TupleSections       #-}
+{-# LANGUAGE TypeFamilies        #-}
+
+import Control.Monad ( foldM )
+import Control.Monad.Random ( MonadRandom, getRandomR )
+
+import Data.List ( cycle, unfoldr )
+import qualified Numeric.LinearAlgebra.Static as SA
+
+import Options.Applicative
+
+import Grenade
+import Grenade.Recurrent
+
+-- The definition for our simple recurrent network.
+-- This file just trains a network to generate a repeating sequence
+-- of 0 0 1.
+--
+-- The F and R types are tagging types to ensure that the runner and
+-- creation function know how to treat the layers.
+type F = FeedForward
+type R = Recurrent
+
+type RecNet = RecurrentNetwork '[ R (LSTM 1 4), R (LSTM 4 1), F Trivial ]
+                               '[ 'D1 1, 'D1 4, 'D1 1, 'D1 1 ]
+
+type RecInput = RecurrentInputs '[ R (LSTM 1 4), R (LSTM 4 1), F Trivial ]
+
+randomNet :: MonadRandom m => m (RecNet, RecInput)
+randomNet = randomRecurrent
+
+netTest :: MonadRandom m => RecNet -> RecInput -> LearningParameters -> Int -> m (RecNet, RecInput)
+netTest net0 i0 rate iterations =
+    foldM trainIteration (net0, i0) [1..iterations]
+  where
+    trainingCycle = cycle [c 0, c 0, c 1]
+
+    trainIteration (net, io) _ = do
+      dropping <- getRandomR (0, 2)
+      count    <- getRandomR (5, 30)
+      let t = drop dropping trainingCycle
+      let example = ((,Nothing) <$> take count t) ++ [(t !! count, Just $ t !! (count + 1))]
+      return $ trainEach net io example
+
+    trainEach !nt !io !ex = trainRecurrent rate nt io ex
+
+data FeedForwardOpts = FeedForwardOpts Int LearningParameters
+
+feedForward' :: Parser FeedForwardOpts
+feedForward' = FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 20000)
+                               <*> (LearningParameters
+                                   <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
+                                   <*> option auto (long "momentum" <> value 0.9)
+                                   <*> option auto (long "l2" <> value 0.0005)
+                                   )
+
+generateRecurrent :: RecNet -> RecInput -> S ('D1 1) -> [Int]
+generateRecurrent n s i =
+    unfoldr go (s, i)
+  where
+    go (x, y) =
+      do let (ns, o) = runRecurrent n x y
+             o'      = heat o
+         Just (o', (ns, fromIntegral o'))
+
+heat :: S ('D1 1) -> Int
+heat x = case x of
+  (S1D v) -> round (SA.mean v)
+
+main :: IO ()
+main = do
+  FeedForwardOpts examples rate <- execParser (info (feedForward' <**> helper) idm)
+  putStrLn "Training network..."
+
+  (net0, i0) <- randomNet
+  (trained, bestInput) <- netTest net0 i0 rate examples
+
+  let results = generateRecurrent trained bestInput (c 1)
+
+  print . take 50 . drop 100 $ results
+
+c :: Double -> S ('D1 1)
+c = S1D . SA.konst
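To make the training data concrete: `trainIteration` pairs each input with `Nothing` and puts the target only on the final step. Tracing it with `dropping = 1` and `count = 4` (illustrative values inside the random ranges above):

```haskell
-- The same list trainIteration builds, factored out for inspection.
exampleFor :: Int -> Int -> [(S ('D1 1), Maybe (S ('D1 1)))]
exampleFor dropping count =
  let t = drop dropping (cycle [c 0, c 0, c 1])
  in  ((,Nothing) <$> take count t) ++ [(t !! count, Just (t !! (count + 1)))]

-- exampleFor 1 4 is, schematically:
--   [(0, Nothing), (1, Nothing), (0, Nothing), (0, Nothing), (1, Just 0)]
```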
main/shakespeare.hs (new file, 156 lines)
@@ -0,0 +1,156 @@
+{-# LANGUAGE BangPatterns        #-}
+{-# LANGUAGE RecordWildCards     #-}
+{-# LANGUAGE DataKinds           #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators       #-}
+{-# LANGUAGE TupleSections       #-}
+{-# LANGUAGE TypeFamilies        #-}
+{-# LANGUAGE LambdaCase          #-}
+
+import Control.Monad.Random
+import Control.Monad.Trans.Except
+
+import Data.Char ( isUpper, toUpper, toLower )
+import Data.List ( unfoldr, foldl' )
+import Data.Maybe ( fromMaybe )
+
+import qualified Data.Vector as V
+import Data.Vector ( Vector )
+
+import qualified Data.Map as M
+import Data.Proxy ( Proxy (..) )
+
+import Data.Singletons.Prelude
+import GHC.TypeLits
+
+import Numeric.LinearAlgebra.Static ( konst )
+
+import Options.Applicative
+
+import Grenade
+import Grenade.Recurrent
+import Grenade.Utils.OneHot
+
+-- The definition for our natural language recurrent network.
+-- This network is able to learn and generate simple words in
+-- about an hour.
+--
+-- This is a first class recurrent net, although it's similar to
+-- an unrolled graph.
+--
+-- The F and R types are tagging types to ensure that the runner and
+-- creation function know how to treat the layers.
+--
+-- As an example, here's a short sequence generated.
+--
+-- > the see and and the sir, and and the make and the make and go the make and go the make and the
+--
+type F = FeedForward
+type R = Recurrent
+
+-- The definition of our network
+type Shakespeare = RecurrentNetwork '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit ]
+                                    '[ 'D1 40, 'D1 40, 'D1 40, 'D1 40 ]
+
+-- The definition of the "sideways" input, which the network is fed recurrently.
+type Shakespearian = RecurrentInputs '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit ]
+
+randomNet :: MonadRandom m => m (Shakespeare, Shakespearian)
+randomNet = randomRecurrent
+
+-- | Load the data files and prepare a map of characters to a compressed int representation.
+loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Vector Char)
+loadShakespeare path = do
+  contents <- lift $ readFile path
+  let annotated = annotateCapitals contents
+  (m, cs) <- ExceptT . return . note "Couldn't fit data in hotMap" $ hotMap (Proxy :: Proxy 40) annotated
+  hot     <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated
+  return (V.fromList hot, m, cs)
+
+trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
+trainSlice !rate !net !recIns input offset size =
+    let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
+    in case reverse e of
+         (o : l : xs) ->
+           let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
+           in  trainRecurrent rate net recIns examples
+         _ -> error "Not enough input"
+  where
+    x = fromMaybe (error "Hot variable didn't fit.")
+
+runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
+runShakespeare ShakespeareOpts {..} = do
+  (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
+  (net0, i0) <- lift randomNet
+  lift $ foldM_ (\(!net, !io) size -> do
+    xs <- take (iterations `div` 15) <$> getRandomRs (0, length shakespeare - size - 1)
+    let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
+    let results = take 100 $ generateParagraph trained bestInput oneHotMap oneHotDictionary ( S1D $ konst 0 )
+    putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
+    putStrLn (unAnnotateCapitals results)
+    return (trained, bestInput)
+    ) (net0, i0) [10,10,15,15,20,20,25,25,30,30,35,35,40,40,50 :: Int]
+
+generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes ~ 'D1 n, KnownNat n, Ord a)
+                  => RecurrentNetwork layers shapes
+                  -> RecurrentInputs layers
+                  -> M.Map a Int
+                  -> Vector a
+                  -> S ('D1 n)
+                  -> [a]
+generateParagraph n s hotmap hotdict i =
+    unfoldr go (s, i)
+  where
+    go (x, y) =
+      do let (ns, o) = runRecurrent n x y
+         un <- unHot hotdict o
+         re <- makeHot hotmap un
+         Just (un, (ns, re))
+
+data ShakespeareOpts = ShakespeareOpts {
+    trainingFile :: FilePath
+  , iterations   :: Int
+  , rate         :: LearningParameters
+  }
+
+shakespeare' :: Parser ShakespeareOpts
+shakespeare' = ShakespeareOpts <$> argument str (metavar "TRAIN")
+                               <*> option auto (long "examples" <> short 'e' <> value 1000000)
+                               <*> (LearningParameters
+                                   <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
+                                   <*> option auto (long "momentum" <> value 0.95)
+                                   <*> option auto (long "l2" <> value 0.000001)
+                                   )
+
+main :: IO ()
+main = do
+  shopts <- execParser (info (shakespeare' <**> helper) idm)
+  res    <- runExceptT $ runShakespeare shopts
+  case res of
+    Right () -> pure ()
+    Left err -> putStrLn err
+
+-- Replace capitals with an annotation and the lower case letter
+-- http://fastml.com/one-weird-trick-for-training-char-rnns/
+annotateCapitals :: String -> String
+annotateCapitals (x : rest)
+    | isUpper x
+    = '^' : toLower x : annotateCapitals rest
+    | otherwise
+    = x : annotateCapitals rest
+annotateCapitals []
+    = []
+
+unAnnotateCapitals :: String -> String
+unAnnotateCapitals ('^' : x : rest)
+    = toUpper x : unAnnotateCapitals rest
+unAnnotateCapitals (x : rest)
+    = x : unAnnotateCapitals rest
+unAnnotateCapitals []
+    = []
+
+-- | Tag the 'Nothing' value of a 'Maybe'
+note :: a -> Maybe b -> Either a b
+note a = maybe (Left a) Right
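The capital-annotation trick is easy to check by hand; a short doctest-style session against the two functions above (note the round trip only holds for text that contains no literal '^'):

```haskell
-- >>> annotateCapitals "Hello World"
-- "^hello ^world"
-- >>> unAnnotateCapitals (annotateCapitals "Hello World")
-- "Hello World"
roundTrips :: String -> Bool
roundTrips s = unAnnotateCapitals (annotateCapitals s) == s
```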
src/Grenade/Core/Network.hs
@@ -1,15 +1,21 @@
 {-# LANGUAGE DataKinds             #-}
 {-# LANGUAGE GADTs                 #-}
-{-# LANGUAGE KindSignatures        #-}
-{-# LANGUAGE ScopedTypeVariables   #-}
 {-# LANGUAGE TypeOperators         #-}
 {-# LANGUAGE TypeFamilies          #-}
-{-# LANGUAGE PolyKinds             #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts      #-}
 {-# LANGUAGE FlexibleInstances     #-}
-{-# LANGUAGE LambdaCase            #-}
+{-|
+Module      : Grenade.Core.Network
+Description : Core definition of a simple neural network
+Copyright   : (c) Huw Campbell, 2016-2017
+License     : BSD2
+Stability   : experimental
+
+This module defines the core data type for the simplest
+neural network we support.
+-}
 module Grenade.Core.Network (
     Layer (..)
   , Network (..)
@@ -20,10 +26,12 @@ module Grenade.Core.Network (
   ) where
 
 import Control.Monad.Random (MonadRandom)
+import Data.List ( foldl' )
+import Data.Singletons
 
 import Grenade.Core.Shape
 
+-- | Learning parameters for stochastic gradient descent.
 data LearningParameters = LearningParameters {
     learningRate :: Double
   , learningMomentum :: Double
@@ -33,35 +41,43 @@ data LearningParameters = LearningParameters {
 -- | Class for updating a layer. All layers implement this, and it is
 --   shape independent.
 class Show x => UpdateLayer x where
+  {-# MINIMAL runUpdate, createRandom #-}
   -- | The type for the gradient for this layer.
   --   Unit if there isn't a gradient to pass back.
   type Gradient x :: *
   -- | Update a layer with its gradient and learning parameters
   runUpdate :: LearningParameters -> x -> Gradient x -> x
 
   -- | Create a random layer, many layers will use pure
   createRandom :: MonadRandom m => m x
 
+  -- | Update a layer with many Gradients
+  runUpdates :: LearningParameters -> x -> [Gradient x] -> x
+  runUpdates rate = foldl' (runUpdate rate)
+
 -- | Class for a layer. All layers implement this, however, they don't
 --   need to implement it for all shapes, only ones which are appropriate.
 class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
   -- | Used in training and scoring. Take the input from the previous
   --   layer, and give the output from this layer.
-  runForwards :: x -> S' i -> S' o
+  runForwards :: x -> S i -> S o
   -- | Back propagate a step. Takes the current layer, the input that the
   --   layer gave from the input and the back propagated derivatives from
   --   the layer above.
   --   Returns the gradient layer and the derivatives to push back further.
-  runBackwards :: x -> S' i -> S' o -> (Gradient x, S' i)
+  runBackwards :: x -> S i -> S o -> (Gradient x, S i)
 
 -- | Type of a network.
--- The [*] type specifies the types of the layers. This is needed for parallel
-- running and being all the gradients beck together.
+--
+-- The [*] type specifies the types of the layers.
+--
 -- The [Shape] type specifies the shapes of data passed between the layers.
--- Could be considered to be a heterogeneous list of layers which are able to
+--
+-- Can be considered to be a heterogeneous list of layers which are able to
 -- transform the data shapes of the network.
 data Network :: [*] -> [Shape] -> * where
-    O     :: Layer x i o => !x -> Network '[x] '[i, o]
-    (:~>) :: Layer x i h => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
+    O     :: (SingI i, SingI o, Layer x i o) => !x -> Network '[x] '[i, o]
+    (:~>) :: (SingI i, SingI h, Layer x i h) => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
 infixr 5 :~>
 
 instance Show (Network l h) where
@@ -74,15 +90,14 @@ data Gradients :: [*] -> * where
     OG    :: UpdateLayer x => Gradient x -> Gradients '[x]
     (:/>) :: UpdateLayer x => Gradient x -> Gradients xs -> Gradients (x ': xs)
 
 
 -- | A network can easily be created by hand with (:~>), but an easy way to initialise a random
 --   network is with the randomNetwork.
 class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
   -- | Create a network of the types requested
   randomNetwork :: MonadRandom m => m (Network xs ss)
 
-instance Layer x i o => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
+instance (SingI i, SingI o, Layer x i o) => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
   randomNetwork = O <$> createRandom
 
-instance (Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
+instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
   randomNetwork = (:~>) <$> createRandom <*> randomNetwork
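The new `runUpdates` default method is just a strict left fold of `runUpdate`, which any instance gets for free; a hedged usage sketch (the layer and gradient values are hypothetical):

```haskell
-- Folding a mini-batch of gradients into one layer via the new default.
applyBatch :: UpdateLayer x => LearningParameters -> x -> [Gradient x] -> x
applyBatch = runUpdates   -- same as: \p l gs -> foldl' (runUpdate p) l gs
```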
src/Grenade/Core/Runner.hs
@@ -4,7 +4,16 @@
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators       #-}
 {-# LANGUAGE TypeFamilies        #-}
+{-|
+Module      : Grenade.Core.Runner
+Description : Training functions for a network
+Copyright   : (c) Huw Campbell, 2016-2017
+License     : BSD2
+Stability   : experimental
+
+This module defines simple back propagation and training functions
+for a network.
+-}
 module Grenade.Core.Runner (
     train
   , backPropagate
@@ -16,16 +25,22 @@ import Data.Singletons.Prelude
 import Grenade.Core.Network
 import Grenade.Core.Shape
 
--- | Drive and network and collect its back propogated gradients.
-backPropagate :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
-              => Network layers shapes -> S' input -> S' output -> Gradients layers
+-- | Perform reverse automatic differentiation on the network
+--   for the current input and expected output.
+--
+--   /Note:/ The loss function pushed backwards is appropriate
+--   for both regression and classification as a squared loss
+--   or log-loss respectively. Other loss functions are not yet
+--   implemented.
+backPropagate :: forall shapes layers.
+                 Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Gradients layers
 backPropagate network input target =
     fst $ go input network
   where
-    go :: forall j js sublayers. (Head js ~ j, Last js ~ output)
-       => S' j                  -- ^ input vector
+    go :: forall js sublayers. (Last js ~ Last shapes)
+       => S (Head js)           -- ^ input vector
       -> Network sublayers js  -- ^ network to train
-       -> (Gradients sublayers, S' j)
+       -> (Gradients sublayers, S (Head js))
     -- handle input from the beginning, feeding upwards.
     go !x (layer :~> n)
      = let y = runForwards layer x
@@ -44,16 +59,7 @@ backPropagate network input target =
 
         in (OG layer', dWs)
 
--- | Update a network with new weights after training with an instance.
-train :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
-      => LearningParameters    -- ^ learning rate
-      -> Network layers shapes -- ^ network to train
-      -> S' input -> S' output -- ^ target vector
-      -> Network layers shapes
-train rate network input output =
-    let grads = backPropagate network input output
-    in  applyUpdate rate network grads
-
+-- | Apply one step of stochastic gradient descent across the network.
 applyUpdate :: LearningParameters -> Network ls ss -> Gradients ls -> Network ls ss
 applyUpdate rate (O layer) (OG gradient)
   = O (runUpdate rate layer gradient)
@@ -62,9 +68,13 @@ applyUpdate rate (layer :~> rest) (gradient :/> grest)
 applyUpdate _ _ _
   = error "Impossible for the gradients of a network to have a different length to the network"
 
--- | Just forwards propagation with no training.
-runNet :: Network layers hs
-       -> S' (Head hs)          -- ^ input vector
-       -> S' (Last hs)          -- ^ target vector
+-- | Update a network with new weights after training with an instance.
+train :: LearningParameters -> Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Network layers shapes
+train rate network input output =
+    let grads = backPropagate network input output
+    in  applyUpdate rate network grads
+
+-- | Run the network with input and return its output.
+runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes)
 runNet (layer :~> n) !x = let y = runForwards layer x in runNet n y
 runNet (O layer) !x     = runForwards layer x
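The /Note:/ in the new haddock is the standard observation that one backwards seed serves both losses; with output o and target t (standard calculus, not taken from the diff):

```latex
\frac{\partial}{\partial o}\,\tfrac{1}{2}\lVert o - t\rVert^{2} = o - t,
\qquad
\frac{\partial}{\partial z}\Bigl(-t\log\sigma(z) - (1-t)\log\bigl(1-\sigma(z)\bigr)\Bigr) = \sigma(z) - t .
```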
src/Grenade/Core/Shape.hs
@@ -1,75 +1,171 @@
 {-# LANGUAGE DataKinds           #-}
 {-# LANGUAGE GADTs               #-}
 {-# LANGUAGE KindSignatures      #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeOperators       #-}
 {-# LANGUAGE TypeFamilies        #-}
-{-# LANGUAGE PolyKinds           #-}
-{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE TypeOperators       #-}
+{-# LANGUAGE StandaloneDeriving  #-}
 {-# LANGUAGE FlexibleContexts    #-}
-{-# LANGUAGE FlexibleInstances   #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE RankNTypes          #-}
 
--- Ghc 8.0 gives a warning on `(+) _ _ = error ...` but ghc 7.10 fails to
+-- Ghc 8.0 gives a warning on `n2 _ _ = error ...` but ghc 7.10 fails to
 -- compile without this default pattern.
 {-# OPTIONS_GHC -fno-warn-overlapping-patterns #-}
 
+{-|
+Module      : Grenade.Core.Shape
+Description : Core definition of the Shapes of data we understand
+Copyright   : (c) Huw Campbell, 2016-2017
+License     : BSD2
+Stability   : experimental
+
+This module defines the core data types for the shapes of data that
+are understood by Grenade.
+-}
 module Grenade.Core.Shape (
     Shape (..)
-  , S' (..)
+  , S (..)
+  , randomOfShape
+  , fromStorable
   ) where
 
+import           Control.DeepSeq (NFData (..))
+import           Control.Monad.Random ( MonadRandom, getRandom )
+
+import           Data.Singletons
 import           Data.Singletons.TypeLits
+import           Data.Vector.Storable ( Vector )
+import qualified Data.Vector.Storable as V
 
 import           GHC.TypeLits
 
+import qualified Numeric.LinearAlgebra.Static as H
 import           Numeric.LinearAlgebra.Static
+import qualified Numeric.LinearAlgebra as NLA
 
 -- | The current shapes we accept.
 --   at the moment this is just one, two, and three dimensional
 --   Vectors/Matricies.
-data Shape =
-    D1 Nat
+data Shape
+  = D1 Nat
   | D2 Nat Nat
   | D3 Nat Nat Nat
 
-instance Num (S' x) where
-  (+) (S1D' x) (S1D' y) = S1D' (x + y)
-  (+) (S2D' x) (S2D' y) = S2D' (x + y)
-  (+) (S3D' x) (S3D' y) = S3D' (x + y)
-  (+) _ _ = error "Impossible to have different constructors for the same shaped network"
-
-  (-) (S1D' x) (S1D' y) = S1D' (x - y)
-  (-) (S2D' x) (S2D' y) = S2D' (x - y)
-  (-) (S3D' x) (S3D' y) = S3D' (x - y)
-  (-) _ _ = error "Impossible to have different constructors for the same shaped network"
-
-  (*) (S1D' x) (S1D' y) = S1D' (x * y)
-  (*) (S2D' x) (S2D' y) = S2D' (x * y)
-  (*) (S3D' x) (S3D' y) = S3D' (x * y)
-  (*) _ _ = error "Impossible to have different constructors for the same shaped network"
-
-  abs (S1D' x) = S1D' (abs x)
-  abs (S2D' x) = S2D' (abs x)
-  abs (S3D' x) = S3D' (abs x)
-
-  signum (S1D' x) = S1D' (signum x)
-  signum (S2D' x) = S2D' (signum x)
-  signum (S3D' x) = S3D' (signum x)
-
-  fromInteger _ = error "Unimplemented: fromInteger on Shape"
-
 -- | Given a Shape n, these are the possible data structures with that shape.
 --   All shapes are held in contiguous memory.
 --   3D is held in a matrix (usually row oriented) which has height depth * rows.
-data S' (n :: Shape) where
-  S1D' :: ( KnownNat o )                      => R o -> S' ('D1 o)
-  S2D' :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S' ('D2 rows columns)
-  S3D' :: ( KnownNat rows
+data S (n :: Shape) where
+  S1D :: ( KnownNat o )                      => R o -> S ('D1 o)
+  S2D :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S ('D2 rows columns)
+  S3D :: ( KnownNat rows
         , KnownNat columns
        , KnownNat depth
-        , KnownNat (rows * depth)) => L (rows * depth) columns -> S' ('D3 rows columns depth)
+        , KnownNat (rows * depth)) => L (rows * depth) columns -> S ('D3 rows columns depth)
 
-instance Show (S' n) where
-  show (S1D' a) = "S1D' " ++ show a
-  show (S2D' a) = "S2D' " ++ show a
-  show (S3D' a) = "S3D' " ++ show a
+deriving instance Show (S n)
+
+instance SingI x => Num (S x) where
+  (+)    = n2 (+)
+  (-)    = n2 (-)
+  (*)    = n2 (*)
+  abs    = n1 abs
+  signum = n1 signum
+  fromInteger x = case (sing :: Sing x) of
+    D1Sing -> S1D (konst $ fromInteger x)
+    D2Sing -> S2D (konst $ fromInteger x)
+    D3Sing -> S3D (konst $ fromInteger x)
+
+instance SingI x => Fractional (S x) where
+  (/)   = n2 (/)
+  recip = n1 recip
+  fromRational x = case (sing :: Sing x) of
+    D1Sing -> S1D (konst $ fromRational x)
+    D2Sing -> S2D (konst $ fromRational x)
+    D3Sing -> S3D (konst $ fromRational x)
+
+instance SingI x => Floating (S x) where
+  pi = case (sing :: Sing x) of
+    D1Sing -> S1D (konst pi)
+    D2Sing -> S2D (konst pi)
+    D3Sing -> S3D (konst pi)
+  exp     = n1 exp
+  log     = n1 log
+  sqrt    = n1 sqrt
+  (**)    = n2 (**)
+  logBase = n2 logBase
+  sin     = n1 sin
+  cos     = n1 cos
+  tan     = n1 tan
+  asin    = n1 asin
+  acos    = n1 acos
+  atan    = n1 atan
+  sinh    = n1 sinh
+  cosh    = n1 cosh
+  tanh    = n1 tanh
+  asinh   = n1 asinh
+  acosh   = n1 acosh
+  atanh   = n1 atanh
+
+-- Singletons
+-- These could probably be derived with template haskell, but this seems
+-- clear and makes adding the KnownNat constraints simple.
+data instance Sing (n :: Shape) where
+  D1Sing :: KnownNat a => Sing ('D1 a)
+  D2Sing :: (KnownNat a, KnownNat b) => Sing ('D2 a b)
+  D3Sing :: (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => Sing ('D3 a b c)
+
+instance KnownNat a => SingI ('D1 a) where
+  sing = D1Sing
+instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
+  sing = D2Sing
+instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
+  sing = D3Sing
+
+--
+-- I haven't made shapes strict, as sometimes they're not needed
+-- (the last input gradient back for instance)
+--
+instance NFData (S x) where
+  rnf (S1D x) = rnf x
+  rnf (S2D x) = rnf x
+  rnf (S3D x) = rnf x
+
+-- | Generate random data of the desired shape
+randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x)
+randomOfShape = do
+  seed :: Int <- getRandom
+  return $ case (sing :: Sing x) of
+    D1Sing -> S1D (randomVector seed Uniform * 2 - 1)
+    D2Sing -> S2D (uniformSample seed (-1) 1)
+    D3Sing -> S3D (uniformSample seed (-1) 1)
+
+-- | Generate a shape from a Storable Vector.
+--
+--   Returns Nothing if the vector is of the wrong size.
+fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x)
+fromStorable xs = case sing :: Sing x of
+    D1Sing -> S1D <$> H.create xs
+    D2Sing -> S2D <$> mkL xs
+    D3Sing -> S3D <$> mkL xs
+  where
+    mkL :: forall rows columns. (KnownNat rows, KnownNat columns)
+        => Vector Double -> Maybe (L rows columns)
+    mkL v =
+      let rows    = fromIntegral $ natVal (Proxy :: Proxy rows)
+          columns = fromIntegral $ natVal (Proxy :: Proxy columns)
+      in  if rows * columns == V.length v
+             then H.create $ NLA.reshape columns v
+             else Nothing
+
+-- Helper function for creating the number instances
+n1 :: ( forall a. Floating a => a -> a ) -> S x -> S x
+n1 f (S1D x) = S1D (f x)
+n1 f (S2D x) = S2D (f x)
+n1 f (S3D x) = S3D (f x)
+
+-- Helper function for creating the number instances
+n2 :: ( forall a. Floating a => a -> a -> a ) -> S x -> S x -> S x
+n2 f (S1D x) (S1D y) = S1D (f x y)
+n2 f (S2D x) (S2D y) = S2D (f x y)
+n2 f (S3D x) (S3D y) = S3D (f x y)
+n2 _ _ _ = error "Impossible to have different constructors for the same shaped network"
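To see what the new `Sing`/`SingI` machinery buys, here is a hedged sketch (not part of the commit) of a shape-polymorphic constant, written exactly the way `fromInteger` and `pi` dispatch above:

```haskell
-- Hypothetical helper: pick the right constructor by pattern matching on
-- the singleton for the type-level shape.
zeroOfShape :: forall x. SingI x => S x
zeroOfShape = case (sing :: Sing x) of
  D1Sing -> S1D (konst 0)
  D2Sing -> S2D (konst 0)
  D3Sing -> S3D (konst 0)

-- e.g. zeroOfShape :: S ('D2 28 28) is an all-zero 28x28 matrix.
```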
src/Grenade/Graph/GraphNetwork.hs (new file, 37 lines)
@@ -0,0 +1,37 @@
+{-# LANGUAGE DataKinds             #-}
+{-# LANGUAGE GADTs                 #-}
+{-# LANGUAGE KindSignatures        #-}
+{-# LANGUAGE ScopedTypeVariables   #-}
+{-# LANGUAGE TypeOperators         #-}
+{-# LANGUAGE TypeFamilies          #-}
+{-# LANGUAGE PolyKinds             #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE FlexibleContexts      #-}
+{-# LANGUAGE FlexibleInstances     #-}
+{-# LANGUAGE LambdaCase            #-}
+
+module Grenade.Graph.Network (
+    Layer (..)
+  , UpdateLayer (..)
+  ) where
+
+import Control.Monad.Random (MonadRandom)
+import Data.Singletons
+import Data.Singletons.Prelude
+
+import GHC.TypeLits
+
+import Grenade.Core.Shape
+import Grenade.Core.Network ( UpdateLayer (..), Layer (..) )
+
+-- | Type of a DAG network
+
+data Fin :: Nat -> * where
+  Fin0 :: Fin (n + 1)
+  FinS :: Fin n -> Fin (n + 1)
+
+data Edge :: Nat -> * where
+  Edge :: Shape -> Fin n -> Edge n
+
+data Node a n where
+  Node :: a -> [Edge n] -> Node a n
src/Grenade/Layers/Convolution.hs
@@ -1,7 +1,5 @@
-{-# LANGUAGE BangPatterns          #-}
 {-# LANGUAGE DataKinds             #-}
 {-# LANGUAGE ScopedTypeVariables   #-}
-{-# LANGUAGE StandaloneDeriving    #-}
 {-# LANGUAGE RecordWildCards       #-}
 {-# LANGUAGE GADTs                 #-}
 {-# LANGUAGE TypeOperators         #-}
@@ -9,9 +7,6 @@
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleInstances     #-}
 {-# LANGUAGE FlexibleContexts      #-}
-{-# LANGUAGE PolyKinds             #-}
-{-# LANGUAGE PatternGuards         #-}
 
 module Grenade.Layers.Convolution (
     Convolution (..)
   , Convolution' (..)
@@ -31,6 +26,7 @@ import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)
 import Grenade.Core.Network
 import Grenade.Core.Shape
 import Grenade.Layers.Internal.Convolution
+import Grenade.Layers.Internal.Update
 
 -- | A convolution layer for a neural network.
 --   This uses the im2col convolution trick popularised by Caffe, which essentially turns the
@@ -43,12 +39,12 @@ import Grenade.Layers.Internal.Convolution
 --   `out = (in - kernel) / stride + 1` for both dimensions.
 --
 --   One probably shouldn't build their own layer, but rather use the randomConvolution function.
-data Convolution :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
-                 -> Nat -- ^ Number of filters, this is the number of channels output by the layer.
-                 -> Nat -- ^ The number of rows in the kernel filter
-                 -> Nat -- ^ The number of column in the kernel filter
-                 -> Nat -- ^ The row stride of the convolution filter
-                 -> Nat -- ^ The columns stride of the convolution filter
+data Convolution :: Nat -- Number of channels, for the first layer this could be RGB for instance.
+                 -> Nat -- Number of filters, this is the number of channels output by the layer.
+                 -> Nat -- The number of rows in the kernel filter
+                 -> Nat -- The number of column in the kernel filter
+                 -> Nat -- The row stride of the convolution filter
+                 -> Nat -- The columns stride of the convolution filter
                  -> * where
   Convolution :: ( KnownNat channels
                  , KnownNat filters
@ -58,16 +54,16 @@ data Convolution :: Nat -- ^ Number of channels, for the first layer this could
|
|||||||
, KnownNat strideColumns
|
, KnownNat strideColumns
|
||||||
, KnownNat kernelFlattened
|
, KnownNat kernelFlattened
|
||||||
, kernelFlattened ~ (kernelRows * kernelColumns * channels))
|
, kernelFlattened ~ (kernelRows * kernelColumns * channels))
|
||||||
=> !(L kernelFlattened filters) -- ^ The kernel filter weights
|
=> !(L kernelFlattened filters) -- The kernel filter weights
|
||||||
-> !(L kernelFlattened filters) -- ^ The last kernel update (or momentum)
|
-> !(L kernelFlattened filters) -- The last kernel update (or momentum)
|
||||||
-> Convolution channels filters kernelRows kernelColumns strideRows strideColumns
|
-> Convolution channels filters kernelRows kernelColumns strideRows strideColumns
|
||||||
|
|
||||||
data Convolution' :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
|
data Convolution' :: Nat -- Number of channels, for the first layer this could be RGB for instance.
|
||||||
-> Nat -- ^ Number of filters, this is the number of channels output by the layer.
|
-> Nat -- Number of filters, this is the number of channels output by the layer.
|
||||||
-> Nat -- ^ The number of rows in the kernel filter
|
-> Nat -- The number of rows in the kernel filter
|
||||||
-> Nat -- ^ The number of column in the kernel filter
|
-> Nat -- The number of column in the kernel filter
|
||||||
-> Nat -- ^ The row stride of the convolution filter
|
-> Nat -- The row stride of the convolution filter
|
||||||
-> Nat -- ^ The columns stride of the convolution filter
|
-> Nat -- The columns stride of the convolution filter
|
||||||
-> * where
|
-> * where
|
||||||
Convolution' :: ( KnownNat channels
|
Convolution' :: ( KnownNat channels
|
||||||
, KnownNat filters
|
, KnownNat filters
|
||||||
@ -77,7 +73,7 @@ data Convolution' :: Nat -- ^ Number of channels, for the first layer this could
|
|||||||
, KnownNat strideColumns
|
, KnownNat strideColumns
|
||||||
, KnownNat kernelFlattened
|
, KnownNat kernelFlattened
|
||||||
, kernelFlattened ~ (kernelRows * kernelColumns * channels))
|
, kernelFlattened ~ (kernelRows * kernelColumns * channels))
|
||||||
=> !(L kernelFlattened filters) -- ^ The kernel filter gradient
|
=> !(L kernelFlattened filters) -- The kernel filter gradient
|
||||||
-> Convolution' channels filters kernelRows kernelColumns strideRows strideColumns
|
-> Convolution' channels filters kernelRows kernelColumns strideRows strideColumns
|
||||||
|
|
||||||
instance Show (Convolution c f k k' s s') where
|
instance Show (Convolution c f k k' s s') where
|
||||||
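As a worked check of the `out = (in - kernel) / stride + 1` rule in the comment above (numbers chosen purely for illustration): a 28×28 input with a 5×5 kernel at stride 1 gives (28 − 5) / 1 + 1 = 24, so a 24×24 output per filter, while a 4×4 kernel at stride 2 gives (28 − 4) / 2 + 1 = 13.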
@@ -109,7 +105,7 @@ randomConvolution :: ( MonadRandom m
                      , kernelFlattened ~ (kernelRows * kernelColumns * channels))
                   => m (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
 randomConvolution = do
-    s :: Int <- getRandom
+    s <- getRandom
     let wN = uniformSample s (-1) 1
         mm = konst 0
     return $ Convolution wN mm
@@ -124,9 +120,7 @@ instance ( KnownNat channels
          ) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
   type Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols) = (Convolution' channels filters kernelRows kernelCols strideRows strideCols)
   runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
-    let newMomentum = konst learningMomentum * oldMomentum - konst learningRate * kernelGradient
-        regulariser = konst (learningRegulariser * learningRate) * oldKernel
-        newKernel = oldKernel + newMomentum - regulariser
+    let (newKernel, newMomentum) = decendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
     in Convolution newKernel newMomentum
 
   createRandom = randomConvolution
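The three deleted lines above are exactly the momentum update that `decendMatrix` now performs through the C kernel. As a pure reference for the semantics (a sketch, elementwise over any `Num`, not the commit's implementation):

```haskell
-- One descent step per element: momentum accumulates the step, and the
-- regulariser decays the weights towards zero (L2 weight decay).
decendRef :: Num a => a -> a -> a -> a -> a -> a -> (a, a)
decendRef rate momentum regulariser weight gradient lastUpdate =
  let newMomentum = momentum * lastUpdate - rate * gradient
      newWeight   = weight + newMomentum - (rate * regulariser) * weight
  in  (newWeight, newMomentum)
```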
@@ -146,7 +140,7 @@ instance ( KnownNat kernelRows
          , KnownNat (kernelRows * kernelCols * 1)
          , KnownNat (outputRows * filters)
          ) => Layer (Convolution 1 filters kernelRows kernelCols strideRows strideCols) ('D2 inputRows inputCols) ('D3 outputRows outputCols filters) where
-  runForwards (Convolution kernel _) (S2D' input) =
+  runForwards (Convolution kernel _) (S2D input) =
     let ex = extract input
         ek = extract kernel
         kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@@ -159,9 +153,9 @@ instance ( KnownNat kernelRows
         mt = c LA.<> ek
         r = col2vid 1 1 1 1 ox oy mt
         rs = fromJust . create $ r
-    in S3D' rs
+    in S3D rs
 
-  runBackwards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
+  runBackwards (Convolution kernel _) (S2D input) (S3D dEdy) =
     let ex = extract input
         ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
@@ -183,7 +177,7 @@ instance ( KnownNat kernelRows
         dW = vs LA.<> tr ek
 
         xW = col2im kx ky sx sy ix iy dW
-    in (Convolution' kN, S2D' . fromJust . create $ xW)
+    in (Convolution' kN, S2D . fromJust . create $ xW)
 
 
 -- | A three dimensional image (or 2d with many channels) can have
@@ -203,7 +197,7 @@ instance ( KnownNat kernelRows
          , KnownNat (kernelRows * kernelCols * channels)
          , KnownNat (outputRows * filters)
          ) => Layer (Convolution channels filters kernelRows kernelCols strideRows strideCols) ('D3 inputRows inputCols channels) ('D3 outputRows outputCols filters) where
-  runForwards (Convolution kernel _) (S3D' input) =
+  runForwards (Convolution kernel _) (S3D input) =
     let ex = extract input
         ek = extract kernel
         ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
@@ -219,8 +213,8 @@ instance ( KnownNat kernelRows
         mt = c LA.<> ek
         r = col2vid 1 1 1 1 ox oy mt
         rs = fromJust . create $ r
-    in S3D' rs
-  runBackwards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
+    in S3D rs
+  runBackwards (Convolution kernel _) (S3D input) (S3D dEdy) =
     let ex = extract input
         ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
@@ -243,4 +237,4 @@ instance ( KnownNat kernelRows
         dW = vs LA.<> tr ek
 
         xW = col2vid kx ky sx sy ix iy dW
-    in (Convolution' kN, S3D' . fromJust . create $ xW)
+    in (Convolution' kN, S3D . fromJust . create $ xW)
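For intuition on the `vid2col`/`col2vid` calls in these instances (illustrative numbers, not from the commit): a 5×5 single-channel input with a 3×3 kernel at stride 1 has (5 − 3) / 1 + 1 = 3 fitting positions per axis, so the im2col step lays out the 9 patches of 9 values each as a 9×9 matrix. The whole convolution then reduces to one matrix product with the 9×filters kernel matrix, and the col2im/col2vid step folds the product (or, in the backward pass, the gradient) back into image shape.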
@@ -4,10 +4,6 @@
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE PolyKinds #-}
-
 module Grenade.Layers.Crop (
     Crop (..)
   ) where
@@ -50,19 +46,19 @@ instance ( KnownNat cropLeft
          , (inputRows - cropTop - cropBottom) ~ outputRows
          , (inputColumns - cropLeft - cropRight) ~ outputColumns
          ) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
-  runForwards Crop (S2D' input) =
+  runForwards Crop (S2D input) =
     let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
         cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
         nrows = fromIntegral $ natVal (Proxy :: Proxy outputRows)
         ncols = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
         m = extract input
         r = subMatrix (cropt, cropl) (nrows, ncols) m
-    in S2D' . fromJust . create $ r
-  runBackwards _ _ (S2D' dEdy) =
+    in S2D . fromJust . create $ r
+  runBackwards _ _ (S2D dEdy) =
     let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
         cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
         cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight)
         cropb = fromIntegral $ natVal (Proxy :: Proxy cropBottom)
         eo = extract dEdy
         vs = diagBlock [konst 0 (cropt,cropl), eo, konst 0 (cropb,cropr)]
-    in ((), S2D' . fromJust . create $ vs)
+    in ((), S2D . fromJust . create $ vs)
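As a worked example of the constraints above: cropping 2 from every side of a 28×28 input leaves (28 − 2 − 2) × (28 − 2 − 2) = 24×24, and the backward pass re-embeds the 24×24 gradient into a 28×28 matrix of zeros via the `diagBlock` trick. The `Pad` layer further down is the exact dual, with `diagBlock` on its forward pass and `subMatrix` on its backward pass.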
@@ -1,12 +1,7 @@
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE LambdaCase #-}
-
 module Grenade.Layers.Dropout (
     Dropout (..)
   , randomDropout
@@ -45,7 +40,7 @@ randomDropout rate = do
     return $ Dropout xs
 
 instance (KnownNat i) => Layer (Dropout i) ('D1 i) ('D1 i) where
-  runForwards (Dropout drops) (S1D' x) = S1D' $ x * drops
-  runForwards (Pass rate) (S1D' x) = S1D' $ dvmap (* (1 - rate)) x
-  runBackwards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops)
-  runBackwards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x)
+  runForwards (Dropout drops) (S1D x) = S1D $ x * drops
+  runForwards (Pass rate) (S1D x) = S1D $ dvmap (* (1 - rate)) x
+  runBackwards (Dropout drops) _ (S1D x) = ((), S1D $ x * drops)
+  runBackwards (Pass rate) _ (S1D x) = ((), S1D $ dvmap (* (1 - rate)) x)
@@ -1,13 +1,8 @@
-{-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-
 module Grenade.Layers.Flatten (
     FlattenLayer (..)
   ) where
@@ -16,11 +11,16 @@ import Data.Singletons.TypeLits
 import GHC.TypeLits
 
 import Numeric.LinearAlgebra.Static
-import Numeric.LinearAlgebra.Data as LA (flatten, toList)
+import Numeric.LinearAlgebra.Data as LA ( flatten )
 
 import Grenade.Core.Shape
 import Grenade.Core.Network
 
+-- | Flatten Layer
+--
+-- Flattens input down to D1 from either 2D or 3D data.
+--
+-- Can also be used to turn a 3D image with only one channel into a 2D image.
 data FlattenLayer = FlattenLayer
   deriving Show
 
@@ -29,11 +29,18 @@ instance UpdateLayer FlattenLayer where
   runUpdate _ _ _ = FlattenLayer
   createRandom = return FlattenLayer
 
 
 instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * z)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
-  runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
-  runBackwards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
+  runForwards _ (S2D y) = fromJust' . fromStorable . flatten . extract $ y
+  runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
 
 instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
-  runForwards _ (S3D' y) = S1D' . fromList . toList . flatten . extract $ y
-  runBackwards _ _ (S1D' y) = ((), S3D' . fromList . toList . unwrap $ y)
+  runForwards _ (S3D y) = fromJust' . fromStorable . flatten . extract $ y
+  runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
+
+instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer FlattenLayer ('D3 x y z) ('D2 x y) where
+  runForwards _ (S3D y) = S2D y
+  runBackwards _ _ (S2D y) = ((), S3D y)
+
+fromJust' :: Maybe x -> x
+fromJust' (Just x) = x
+fromJust' Nothing = error $ "FlattenLayer error: data shape couldn't be converted."
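A hypothetical sketch of where `FlattenLayer` sits in a network type (the surrounding layers and sizes here are made up for illustration, and `Network` is Grenade's core network type): a 4×5 image with 3 channels flattens to a vector of 4 × 5 × 3 = 60 values, which a fully connected layer can then consume.

```haskell
-- Illustrative only: the shapes thread through the layer list, with
-- FlattenLayer bridging from an image shape to a flat vector.
type TinyNet = Network '[ FlattenLayer, FullyConnected 60 10, Logit ]
                       '[ 'D3 4 5 3, 'D1 60, 'D1 10, 'D1 10 ]
```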
@@ -1,11 +1,8 @@
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
-
 module Grenade.Layers.FullyConnected (
     FullyConnected (..)
   , randomFullyConnected
@@ -20,6 +17,8 @@ import Numeric.LinearAlgebra.Static
 import Grenade.Core.Network
 import Grenade.Core.Shape
 
+import Grenade.Layers.Internal.Update
+
 -- | A basic fully connected (or inner product) neural network layer.
 data FullyConnected i o = FullyConnected
                         !(R o) -- Bias neuron weights
@@ -38,32 +37,29 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
   type Gradient (FullyConnected i o) = (FullyConnected' i o)
 
   runUpdate LearningParameters {..} (FullyConnected oldBias oldBiasMomentum oldActivations oldMomentum) (FullyConnected' biasGradient activationGradient) =
-    let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
-        newBias = oldBias + newBiasMomentum
-        newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
-        regulariser = konst (learningRegulariser * learningRate) * oldActivations
-        newActivations = oldActivations + newMomentum - regulariser
+    let (newBias, newBiasMomentum) = decendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum
+        (newActivations, newMomentum) = decendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum
     in FullyConnected newBias newBiasMomentum newActivations newMomentum
 
   createRandom = randomFullyConnected
 
 instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
   -- Do a matrix vector multiplication and return the result.
-  runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)
+  runForwards (FullyConnected wB _ wN _) (S1D v) = S1D (wB + wN #> v)
 
   -- Run a backpropagation step for a fully connected layer.
-  runBackwards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
+  runBackwards (FullyConnected _ _ wN _) (S1D x) (S1D dEdy) =
     let wB' = dEdy
         mm' = dEdy `outer` x
         -- calculate derivatives for next step
        dWs = tr wN #> dEdy
-    in (FullyConnected' wB' mm', S1D' dWs)
+    in (FullyConnected' wB' mm', S1D dWs)
 
 randomFullyConnected :: (MonadRandom m, KnownNat i, KnownNat o)
                      => m (FullyConnected i o)
 randomFullyConnected = do
-    s1 :: Int <- getRandom
-    s2 :: Int <- getRandom
+    s1 <- getRandom
+    s2 <- getRandom
     let wB = randomVector s1 Uniform * 2 - 1
         wN = uniformSample s2 (-1) 1
        bm = konst 0
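The backward pass above is the usual linear-layer calculus: for `y = wB + wN #> x`, the bias gradient is `dEdy` itself, the weight gradient is the outer product ``dEdy `outer` x``, and the error handed to the layer below is `tr wN #> dEdy`, which is exactly what the `wB'`, `mm'` and `dWs` bindings compute.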
@@ -1,16 +1,11 @@
-{-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE PolyKinds #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE FlexibleInstances #-}
-
-
 module Grenade.Layers.Fuse (
     Fuse (..)
   ) where
@@ -42,11 +37,11 @@ instance (Layer x i h, Layer y h o) => UpdateLayer (Fuse x y i h o) where
 
 instance (Layer x i h, Layer y h o) => Layer (Fuse x y i h o) i o where
   runForwards (x :$$ y) input =
-    let yInput :: S' h = runForwards x input
+    let yInput :: S h = runForwards x input
     in runForwards y yInput
 
   runBackwards (x :$$ y) input backGradient =
-    let yInput :: S' h = runForwards x input
+    let yInput :: S h = runForwards x input
         (y', yGrad) = runBackwards y yInput backGradient
         (x', xGrad) = runBackwards x input yGrad
     in ((x', y'), xGrad)
@@ -6,7 +6,9 @@ module Grenade.Layers.Internal.Convolution (
   , vid2col
   ) where
 
-import Foreign ( mallocForeignPtrArray0, withForeignPtr )
+import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
+
+import Foreign ( mallocForeignPtrArray, withForeignPtr )
 import Foreign.Ptr ( Ptr )
 
 import Numeric.LinearAlgebra ( Matrix, flatten, rows, cols )
@@ -28,19 +30,19 @@ col2im_c :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
 col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol =
   let vec = flatten dataCol
   in unsafePerformIO $ do
-    outPtr <- mallocForeignPtrArray0 (height * width * channels)
-    let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
+    outPtr <- mallocForeignPtrArray (height * width * channels)
+    let (inPtr, _) = U.unsafeToForeignPtr0 vec
 
    withForeignPtr inPtr $ \inPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
-        col2im_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
+        col2im_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
 
-    let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
+    let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels)
     return $ U.matrixFromVector U.RowMajor (height * channels) width matVec
 
 foreign import ccall unsafe
    col2im_cpu
-      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
 
 vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
 vid2col kernelRows kernelColumns strideRows strideColumns height width dataVid =
@@ -63,16 +65,16 @@ im2col_c channels height width kernelRows kernelColumns strideRows strideColumns
     kernelSize = kernelRows * kernelColumns
     numberOfPatches = rowOut * colOut
   in unsafePerformIO $ do
-    outPtr <- mallocForeignPtrArray0 (numberOfPatches * kernelSize * channels)
-    let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
+    outPtr <- mallocForeignPtrArray (numberOfPatches * kernelSize * channels)
+    let (inPtr, _) = U.unsafeToForeignPtr0 vec
 
   withForeignPtr inPtr $ \inPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
-        im2col_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
+        im2col_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
 
-    let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * kernelSize * channels)
+    let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * kernelSize * channels)
     return $ U.matrixFromVector U.RowMajor numberOfPatches (kernelSize * channels) matVec
 
 foreign import ccall unsafe
   im2col_cpu
-      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
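To make the buffer sizes here concrete (illustrative numbers): with `rowOut = (height - kernelRows) `div` strideRows + 1`, and likewise for columns, a 5×5 image under a 3×3 kernel at stride 1 gives rowOut = colOut = 3, hence numberOfPatches = 9 patches of kernelSize = 9 values each, i.e. a 9×9 output buffer per channel.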
@@ -4,7 +4,9 @@ module Grenade.Layers.Internal.Pooling (
   , poolBackward
   ) where
 
-import Foreign ( mallocForeignPtrArray0, withForeignPtr )
+import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
+
+import Foreign ( mallocForeignPtrArray, withForeignPtr )
 import Foreign.Ptr ( Ptr )
 
 import Numeric.LinearAlgebra ( Matrix , flatten )
@@ -19,37 +21,37 @@ poolForward channels height width kernelRows kernelColumns strideRows strideColu
     colOut = (width - kernelColumns) `div` strideColumns + 1
     numberOfPatches = rowOut * colOut
   in unsafePerformIO $ do
-    outPtr <- mallocForeignPtrArray0 (numberOfPatches * channels)
-    let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
+    outPtr <- mallocForeignPtrArray (numberOfPatches * channels)
+    let (inPtr, _) = U.unsafeToForeignPtr0 vec
 
    withForeignPtr inPtr $ \inPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
-        pool_forwards_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
+        pool_forwards_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
 
-    let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * channels)
+    let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * channels)
     return $ U.matrixFromVector U.RowMajor (rowOut * channels) colOut matVec
 
 foreign import ccall unsafe
   pool_forwards_cpu
-      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
 
 poolBackward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
 poolBackward channels height width kernelRows kernelColumns strideRows strideColumns dataIm dataGrad =
   let vecIm = flatten dataIm
      vecGrad = flatten dataGrad
  in unsafePerformIO $ do
-    outPtr <- mallocForeignPtrArray0 (height * width * channels)
-    let (imPtr, imOffset, _) = U.unsafeToForeignPtr vecIm
-    let (gradPtr, gradOffset, _) = U.unsafeToForeignPtr vecGrad
+    outPtr <- mallocForeignPtrArray (height * width * channels)
+    let (imPtr, _) = U.unsafeToForeignPtr0 vecIm
+    let (gradPtr, _) = U.unsafeToForeignPtr0 vecGrad
 
   withForeignPtr imPtr $ \imPtr' ->
     withForeignPtr gradPtr $ \gradPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
-        pool_backwards_cpu imPtr' imOffset gradPtr' gradOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
+        pool_backwards_cpu imPtr' gradPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
 
-    let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
+    let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels)
     return $ U.matrixFromVector U.RowMajor (height * channels) width matVec
 
 foreign import ccall unsafe
  pool_backwards_cpu
-      :: Ptr Double -> Int -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+      :: Ptr Double -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
src/Grenade/Layers/Internal/Update.hs (new file)
@@ -0,0 +1,70 @@
+{-# LANGUAGE ForeignFunctionInterface #-}
+module Grenade.Layers.Internal.Update (
+    decendMatrix
+  , decendVector
+  ) where
+
+import Data.Maybe ( fromJust )
+import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
+
+import Foreign ( mallocForeignPtrArray, withForeignPtr )
+import Foreign.Ptr ( Ptr )
+import GHC.TypeLits
+
+import Numeric.LinearAlgebra ( Vector, flatten )
+import Numeric.LinearAlgebra.Static
+import qualified Numeric.LinearAlgebra.Devel as U
+
+import System.IO.Unsafe ( unsafePerformIO )
+
+decendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
+decendMatrix rate momentum regulariser weights gradient lastUpdate =
+  let (rows, cols) = size weights
+      len = rows * cols
+      -- Most gradients come in in ColumnMajor,
+      -- so we'll transpose here before flattening them
+      -- into a vector to prevent a copy.
+      --
+      -- This gives ~15% speed improvement for LSTMs.
+      weights'    = flatten . tr . extract $ weights
+      gradient'   = flatten . tr . extract $ gradient
+      lastUpdate' = flatten . tr . extract $ lastUpdate
+      (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
+
+      -- Note that it's ColumnMajor, as we did a transpose before
+      -- using the internal vectors.
+      mw = U.matrixFromVector U.ColumnMajor rows cols vw
+      mm = U.matrixFromVector U.ColumnMajor rows cols vm
+  in (fromJust . create $ mw, fromJust . create $ mm)
+
+decendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
+decendVector rate momentum regulariser weights gradient lastUpdate =
+  let len = size weights
+      weights'    = extract weights
+      gradient'   = extract gradient
+      lastUpdate' = extract lastUpdate
+      (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
+  in (fromJust $ create vw, fromJust $ create vm)
+
+decendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
+decendUnsafe len rate momentum regulariser weights gradient lastUpdate =
+  unsafePerformIO $ do
+    outWPtr <- mallocForeignPtrArray len
+    outMPtr <- mallocForeignPtrArray len
+    let (wPtr, _) = U.unsafeToForeignPtr0 weights
+    let (gPtr, _) = U.unsafeToForeignPtr0 gradient
+    let (lPtr, _) = U.unsafeToForeignPtr0 lastUpdate
+
+    withForeignPtr wPtr $ \wPtr' ->
+      withForeignPtr gPtr $ \gPtr' ->
+        withForeignPtr lPtr $ \lPtr' ->
+          withForeignPtr outWPtr $ \outWPtr' ->
+            withForeignPtr outMPtr $ \outMPtr' ->
+              decend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'
+
+    return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len)
+
+foreign import ccall unsafe
+  decend_cpu
+    :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO ()
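A hypothetical usage sketch for the new module (the sizes and hyper-parameters are made up, and it assumes `L` from `Numeric.LinearAlgebra.Static` plus the module above in scope): one descent step over a 2×3 weight matrix returns both the updated weights and the updated momentum, which the caller threads into the next step.

```haskell
-- Illustrative only: rate 0.01, momentum 0.9, L2 regulariser 1e-4.
step :: L 2 3 -> L 2 3 -> L 2 3 -> (L 2 3, L 2 3)
step weights gradient lastUpdate =
  decendMatrix 0.01 0.9 1e-4 weights gradient lastUpdate
```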
@@ -1,10 +1,7 @@
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
-
 module Grenade.Layers.Logit (
     Logit (..)
   ) where
@@ -27,17 +24,16 @@ instance UpdateLayer Logit where
   createRandom = return Logit
 
 instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where
-  runForwards _ (S1D' y) = S1D' (logistic y)
-  runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy))
+  runForwards _ (S1D y) = S1D (logistic y)
+  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (logistic' y * dEdy))
 
 instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where
-  runForwards _ (S2D' y) = S2D' (logistic y)
-  runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))
+  runForwards _ (S2D y) = S2D (logistic y)
+  runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (logistic' y * dEdy))
 
 instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where
-  runForwards _ (S3D' y) = S3D' (logistic y)
-  runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (logistic' y * dEdy))
+  runForwards _ (S3D y) = S3D (logistic y)
+  runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (logistic' y * dEdy))
-
 
 logistic :: Floating a => a -> a
 logistic x = 1 / (1 + exp (-x))
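The definition of `logistic'` falls outside this hunk; presumably it is the standard sigmoid derivative, which would make the `logistic' y * dEdy` backward pass a cheap elementwise product:

```haskell
-- Presumed definition (not shown in the hunk): sigma'(x) = sigma(x) * (1 - sigma(x)).
logistic' :: Floating a => a -> a
logistic' x = logistic x * (1 - logistic x)
```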
@@ -4,10 +4,6 @@
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE PolyKinds #-}
-
 module Grenade.Layers.Pad (
     Pad (..)
   ) where
@@ -50,19 +46,19 @@ instance ( KnownNat padLeft
          , (inputRows + padTop + padBottom) ~ outputRows
          , (inputColumns + padLeft + padRight) ~ outputColumns
          ) => Layer (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
-  runForwards Pad (S2D' input) =
+  runForwards Pad (S2D input) =
     let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
         padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
         padr = fromIntegral $ natVal (Proxy :: Proxy padRight)
         padb = fromIntegral $ natVal (Proxy :: Proxy padBottom)
         m = extract input
         r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
-    in S2D' . fromJust . create $ r
-  runBackwards Pad _ (S2D' dEdy) =
+    in S2D . fromJust . create $ r
+  runBackwards Pad _ (S2D dEdy) =
     let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
         padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
         nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         ncols = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
         m = extract dEdy
         vs = subMatrix (padt, padl) (nrows, ncols) m
-    in ((), S2D' . fromJust . create $ vs)
+    in ((), S2D . fromJust . create $ vs)
@@ -1,4 +1,3 @@
-{-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE StandaloneDeriving #-}
@@ -6,10 +5,7 @@
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE PolyKinds #-}
-
 module Grenade.Layers.Pooling (
     Pooling (..)
   ) where
@@ -55,7 +51,7 @@ instance ( KnownNat kernelRows
          , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
          , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
          ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
-  runForwards Pooling (S2D' input) =
+  runForwards Pooling (S2D input) =
     let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         width = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
         kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@@ -65,8 +61,8 @@ instance ( KnownNat kernelRows
         ex = extract input
         r = poolForward 1 height width kx ky sx sy ex
         rs = fromJust . create $ r
-    in S2D' $ rs
-  runBackwards Pooling (S2D' input) (S2D' dEdy) =
+    in S2D $ rs
+  runBackwards Pooling (S2D input) (S2D dEdy) =
     let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         width = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
         kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@@ -76,7 +72,7 @@ instance ( KnownNat kernelRows
         ex = extract input
         eo = extract dEdy
         vs = poolBackward 1 height width kx ky sx sy ex eo
-    in ((), S2D' . fromJust . create $ vs)
+    in ((), S2D . fromJust . create $ vs)
 
 
 -- | A three dimensional image can be pooled on each layer.
@@ -93,7 +89,7 @@ instance ( KnownNat kernelRows
          , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
-  runForwards Pooling (S3D' input) =
+  runForwards Pooling (S3D input) =
     let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
         kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@@ -104,8 +100,8 @@ instance ( KnownNat kernelRows
         ex = extract input
         r = poolForward ch ix iy kx ky sx sy ex
         rs = fromJust . create $ r
-    in S3D' rs
-  runBackwards Pooling (S3D' input) (S3D' dEdy) =
+    in S3D rs
+  runBackwards Pooling (S3D input) (S3D dEdy) =
     let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
         kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@@ -116,4 +112,4 @@ instance ( KnownNat kernelRows
         ex = extract input
         eo = extract dEdy
         vs = poolBackward ch ix iy kx ky sx sy ex eo
-    in ((), S3D' . fromJust . create $ vs)
+    in ((), S3D . fromJust . create $ vs)
@@ -1,10 +1,7 @@
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
-
 module Grenade.Layers.Relu (
     Relu (..)
   ) where
@@ -27,25 +24,25 @@ instance UpdateLayer Relu where
   createRandom = return Relu
 
 instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where
-  runForwards _ (S1D' y) = S1D' (relu y)
+  runForwards _ (S1D y) = S1D (relu y)
     where
      relu = LAS.dvmap (\a -> if a <= 0 then 0 else a)
-  runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy))
+  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (relu' y * dEdy))
    where
      relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1)
 
 instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where
-  runForwards _ (S2D' y) = S2D' (relu y)
+  runForwards _ (S2D y) = S2D (relu y)
    where
      relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
-  runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy))
+  runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (relu' y * dEdy))
    where
      relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
 
 instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j k) where
-  runForwards _ (S3D' y) = S3D' (relu y)
+  runForwards _ (S3D y) = S3D (relu y)
    where
      relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
-  runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (relu' y * dEdy))
+  runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (relu' y * dEdy))
    where
      relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
@@ -1,10 +1,7 @@
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
-
 module Grenade.Layers.Tanh (
     Tanh (..)
   ) where
@@ -24,16 +21,16 @@ instance UpdateLayer Tanh where
   createRandom = return Tanh
 
 instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
-  runForwards _ (S1D' y) = S1D' (tanh y)
-  runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy))
+  runForwards _ (S1D y) = S1D (tanh y)
+  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (tanh' y * dEdy))
 
 instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where
-  runForwards _ (S2D' y) = S2D' (tanh y)
-  runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))
+  runForwards _ (S2D y) = S2D (tanh y)
+  runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (tanh' y * dEdy))
 
 instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where
-  runForwards _ (S3D' y) = S3D' (tanh y)
-  runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (tanh' y * dEdy))
+  runForwards _ (S3D y) = S3D (tanh y)
+  runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (tanh' y * dEdy))
 
 tanh' :: (Floating a) => a -> a
 tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
src/Grenade/Recurrent.hs (new file)
@@ -0,0 +1,9 @@
+module Grenade.Recurrent (
+    module X
+  ) where
+
+import Grenade.Recurrent.Core.Network as X
+import Grenade.Recurrent.Core.Runner as X
+import Grenade.Recurrent.Layers.BasicRecurrent as X
+import Grenade.Recurrent.Layers.LSTM as X
+import Grenade.Recurrent.Layers.Trivial as X
src/Grenade/Recurrent/Core/Network.hs (new file)
@@ -0,0 +1,98 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE EmptyDataDecls #-}
+module Grenade.Recurrent.Core.Network (
+    Recurrent
+  , FeedForward
+  , RecurrentLayer (..)
+  , RecurrentUpdateLayer (..)
+  , RecurrentNetwork (..)
+  , RecurrentInputs (..)
+  , CreatableRecurrent (..)
+  ) where
+
+
+import Control.Monad.Random ( MonadRandom )
+import Data.Singletons ( SingI )
+
+import Grenade.Core.Shape
+import Grenade.Core.Network
+
+
+-- | Witness type to indicate we're building up with a normal feed
+--   forward layer.
+data FeedForward :: * -> *
+-- | Witness type to indicate we're building up with a recurrent layer.
+data Recurrent :: * -> *
+
+-- | Class for a recurrent layer.
+--   It's quite similar to a normal layer, but with an extra recurrent
+--   data shape as input and output.
+class UpdateLayer x => RecurrentUpdateLayer x where
+  -- | Shape of data that is passed between each subsequent run of the layer
+  type RecurrentShape x :: Shape
+
+class (RecurrentUpdateLayer x, SingI (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
+  -- | Used in training and scoring. Take the input from the previous
+  --   layer, and give the output from this layer.
+  runRecurrentForwards :: x -> S (RecurrentShape x) -> S i -> (S (RecurrentShape x), S o)
+  -- | Back propagate a step. Takes the current layer, the input the layer
+  --   was given, and the back propagated derivatives from the layer above.
+  --   Returns the gradient layer and the derivatives to push back further.
+  runRecurrentBackwards :: x -> S (RecurrentShape x) -> S i -> S (RecurrentShape x) -> S o -> (Gradient x, S (RecurrentShape x), S i)
+
+data RecurrentNetwork :: [*] -> [Shape] -> * where
+  OR     :: (SingI i, SingI o, Layer x i o) => !x -> RecurrentNetwork '[FeedForward x] '[i, o]
+  (:~~>) :: (SingI i, Layer x i h) => !x -> !(RecurrentNetwork xs (h ': hs)) -> RecurrentNetwork (FeedForward x ': xs) (i ': h ': hs)
+  (:~@>) :: (SingI i, RecurrentLayer x i h) => !x -> !(RecurrentNetwork xs (h ': hs)) -> RecurrentNetwork (Recurrent x ': xs) (i ': h ': hs)
+infixr 5 :~~>
+infixr 5 :~@>
+
+instance Show (RecurrentNetwork l h) where
+  show (OR a) = "OR " ++ show a
+  show (i :~~> o) = show i ++ "\n:~~>\n" ++ show o
+  show (i :~@> o) = show i ++ "\n:~@>\n" ++ show o
+
+
+-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph)
+--   Parameterised on the layers of a Network.
+data RecurrentInputs :: [*] -> * where
+  ORS     :: UpdateLayer x
+          => () -> RecurrentInputs '[FeedForward x]
+  (:~~+>) :: UpdateLayer x
+          => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs)
+  (:~@+>) :: (SingI (RecurrentShape x), RecurrentUpdateLayer x)
+          => !(S (RecurrentShape x)) -> !(RecurrentInputs xs) -> RecurrentInputs (Recurrent x ': xs)
+infixr 5 :~~+>
+infixr 5 :~@+>
+
+-- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
+--   recurrent network and a set of random inputs for it is with the randomRecurrent function.
+class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where
+  -- | Create a network of the types requested
+  randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss, RecurrentInputs xs)
+
+instance (SingI i, SingI o, Layer x i o) => CreatableRecurrent (FeedForward x ': '[]) (i ': o ': '[]) where
+  randomRecurrent = do
+    thisLayer <- createRandom
+    return (OR thisLayer, ORS ())
+
+instance (SingI i, Layer x i o, CreatableRecurrent xs (o ': r ': rs)) => CreatableRecurrent (FeedForward x ': xs) (i ': o ': r ': rs) where
+  randomRecurrent = do
+    thisLayer <- createRandom
+    (rest, resti) <- randomRecurrent
+    return (thisLayer :~~> rest, () :~~+> resti)
+
+instance (SingI i, RecurrentLayer x i o, CreatableRecurrent xs (o ': r ': rs)) => CreatableRecurrent (Recurrent x ': xs) (i ': o ': r ': rs) where
+  randomRecurrent = do
+    thisLayer <- createRandom
+    thisShape <- randomOfShape
+    (rest, resti) <- randomRecurrent
+    return (thisLayer :~@> rest, thisShape :~@+> resti)
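A hypothetical sketch of the constructors and `randomRecurrent` in use (the concrete layers and sizes are assumptions; `BasicRecurrent` and `FullyConnected` stand in for any layers with the right instances):

```haskell
-- Illustrative only: one recurrent layer feeding a feed-forward output layer.
type RecLayers = '[ Recurrent (BasicRecurrent 4 8), FeedForward (FullyConnected 8 2) ]
type RecNet    = RecurrentNetwork RecLayers '[ 'D1 4, 'D1 8, 'D1 2 ]

initialise :: MonadRandom m => m (RecNet, RecurrentInputs RecLayers)
initialise = randomRecurrent
```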
144
src/Grenade/Recurrent/Core/Runner.hs
Normal file
144
src/Grenade/Recurrent/Core/Runner.hs
Normal file
@ -0,0 +1,144 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Grenade.Recurrent.Core.Runner (
    trainRecurrent
  , runRecurrent
  ) where

import Data.Singletons.Prelude
import Grenade.Core.Network
import Grenade.Core.Shape

import Grenade.Recurrent.Core.Network

-- | Drive a network and collect its back-propagated gradients.
trainRecurrent :: forall shapes layers. SingI (Last shapes)
               => LearningParameters
               -> RecurrentNetwork layers shapes
               -> RecurrentInputs layers
               -> [(S (Head shapes), Maybe (S (Last shapes)))]
               -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
trainRecurrent rate network recinputs examples =
  updateBack $ go inputs network recinputs
    where
  inputs  = fst <$> examples
  targets = snd <$> examples
  updateBack (a, recgrad, _) = (a, updateRecInputs rate recinputs recgrad)

  go :: forall js sublayers. (Last js ~ Last shapes)
     => [S (Head js)]                  -- ^ input vector
     -> RecurrentNetwork sublayers js  -- ^ network to train
     -> RecurrentInputs sublayers
     -> (RecurrentNetwork sublayers js, RecurrentInputs sublayers, [S (Head js)])

  -- This is a simple non-recurrent layer; just map it forwards.
  -- Note we're doing training here; we could just return a list of gradients
  -- (and probably will in future).
  go !xs (layer :~~> n) (() :~~+> nIn)
    = let ys = runForwards layer <$> xs
          -- recursively run the rest of the network, and get the gradients from above.
          (newFN, ig, grads) = go ys n nIn
          -- calculate the gradient for this layer to pass down,
          back = uncurry (runBackwards layer) <$> zip (reverse xs) grads
          -- the new trained layer.
          newlayer = runUpdates rate layer (fst <$> back)
      in (newlayer :~~> newFN, () :~~+> ig, snd <$> back)

  -- This is a recurrent layer, so we need to do a scan, first input to last, providing
  -- the recurrent shape output to the next layer.
  go !xs (layer :~@> n) (g :~@+> nIn)
    = let ys = scanlFrom layer g xs

          (newFN, ig, grads) = go (snd <$> ys) n nIn

          backExamples = zip3 (fst <$> reverse ys) (reverse xs) grads

          (rg, back) = myscanbackward layer backExamples
          -- the new trained layer.
          newlayer = runUpdates rate layer (fst <$> back)
      in (newlayer :~@> newFN, rg :~@+> ig, snd <$> back)

  -- Handle the output layer, bouncing the derivatives back down.
  -- We may not have a target for each example, so when we don't, use a 0 gradient.
  go !xs (OR layer) (ORS ())
    = let ys = runForwards layer <$> xs
          -- run the layer backwards, scoring only the steps that have a target.
          back = uncurry (runBackwards layer) <$> zip xs (zipWith makeError ys targets)
          -- the new trained layer.
          newlayer = runUpdates rate layer (reverse $ fst <$> back)
      in (OR newlayer, ORS (), reverse (snd <$> back))

  go _ _ _ =
    error "Impossible for network and recurrent inputs to have different shapes"

  makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes)
  makeError _ Nothing = 0
  makeError y (Just t) = y - t
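
-- A hypothetical training-loop sketch (not in the original source): @net0@
-- and @ins0@ would come from 'randomRecurrent', @series@ pairs each input
-- with an optional target, and the LearningParameters constructor is assumed
-- here to take rate, momentum and regulariser in that order:
--
-- > step (net, ins) = trainRecurrent (LearningParameters 0.01 0.9 1e-4) net ins series
-- > (trainedNet, trainedIns) = iterate step (net0, ins0) !! 100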

updateRecInputs :: forall sublayers.
                   LearningParameters
                -> RecurrentInputs sublayers
                -> RecurrentInputs sublayers
                -> RecurrentInputs sublayers

updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
  = () :~~+> updateRecInputs l xs ys

updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
  = (x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys

updateRecInputs _ (ORS ()) (ORS ())
  = ORS ()

updateRecInputs _ _ _
  = error "Impossible for updateRecInputs to have different shapes"

scanlFrom :: forall x i o. RecurrentLayer x i o
          => x                             -- ^ the layer
          -> S (RecurrentShape x)          -- ^ place to start
          -> [S i]                         -- ^ list of inputs to scan through
          -> [(S (RecurrentShape x), S o)] -- ^ list of scan inputs and outputs
scanlFrom !layer !recShape (x:xs) =
  let (lerec, lepush) = runRecurrentForwards layer recShape x
  in  (recShape, lepush) : scanlFrom layer lerec xs
scanlFrom _ _ [] = []
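
-- For intuition (an illustrative trace, not in the original source):
--
-- > scanlFrom layer s0 [x1, x2] == [(s0, y1), (s1, y2)]
-- >   where (s1, y1) = runRecurrentForwards layer s0 x1
-- >         (s2, y2) = runRecurrentForwards layer s1 x2
--
-- i.e. each element pairs the state a step was run from with that step's output.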

myscanbackward :: forall x i o. RecurrentLayer x i o
               => x                                           -- ^ the layer
               -> [(S (RecurrentShape x), S i, S o)]          -- ^ the list of inputs and outputs to scan over
               -> (S (RecurrentShape x), [(Gradient x, S i)]) -- ^ list of gradients to fold and inputs to backprop
myscanbackward layer =
  goX 0
    where
  goX :: S (RecurrentShape x) -> [(S (RecurrentShape x), S i, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)])
  goX !lastback ((recShape, lastin, backgrad):xs) =
    let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recShape lastin lastback backgrad
        (pushedback, ll) = goX recgrad xs
    in  (pushedback, (layergrad, ingrad) : ll)
  goX !lastback [] = (lastback, [])

-- | Just forwards propagation with no training.
runRecurrent :: RecurrentNetwork layers shapes
             -> RecurrentInputs layers -> S (Head shapes)
             -> (RecurrentInputs layers, S (Last shapes))
runRecurrent (layer :~~> n) (() :~~+> nr) !x
  = let ys = runForwards layer x
        (nr', o) = runRecurrent n nr ys
    in  (() :~~+> nr', o)
runRecurrent (layer :~@> n) (recin :~@+> nr) !x
  = let (recin', y) = runRecurrentForwards layer recin x
        (nr', o) = runRecurrent n nr y
    in  (recin' :~@+> nr', o)
runRecurrent (OR layer) (ORS ()) !x
  = (ORS (), runForwards layer x)

runRecurrent _ _ _
  = error "Impossible for the gradients of a network to have a different length or shape to the network"
92
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
Normal file
@ -0,0 +1,92 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Grenade.Recurrent.Layers.BasicRecurrent (
    BasicRecurrent (..)
  , randomBasicRecurrent
  ) where

import Control.Monad.Random ( MonadRandom, getRandom )

import Data.Singletons.TypeLits

import Numeric.LinearAlgebra.Static

import GHC.TypeLits

import Grenade.Core.Network
import Grenade.Core.Shape
import Grenade.Recurrent.Core.Network

data BasicRecurrent :: Nat -- Input layer size
                    -> Nat -- Output layer size
                    -> * where
  BasicRecurrent :: ( KnownNat input
                    , KnownNat output
                    , KnownNat matrixCols
                    , matrixCols ~ (input + output))
                 => !(R output)            -- Bias neuron weights
                 -> !(R output)            -- Bias neuron momentum
                 -> !(L output matrixCols) -- Activation
                 -> !(L output matrixCols) -- Momentum
                 -> BasicRecurrent input output

data BasicRecurrent' :: Nat -- Input layer size
                     -> Nat -- Output layer size
                     -> * where
  BasicRecurrent' :: ( KnownNat input
                     , KnownNat output
                     , KnownNat matrixCols
                     , matrixCols ~ (input + output))
                  => !(R output)            -- Bias neuron gradients
                  -> !(L output matrixCols) -- Activation gradients
                  -> BasicRecurrent' input output

instance Show (BasicRecurrent i o) where
  show BasicRecurrent {} = "BasicRecurrent"

instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
  type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)

  runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
    let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
        newBias         = oldBias + newBiasMomentum
        newMomentum     = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
        regulariser     = konst (learningRegulariser * learningRate) * oldActivations
        newActivations  = oldActivations + newMomentum - regulariser
    in BasicRecurrent newBias newBiasMomentum newActivations newMomentum

  createRandom = randomBasicRecurrent

instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentUpdateLayer (BasicRecurrent i o) where
  type RecurrentShape (BasicRecurrent i o) = 'D1 o

instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentLayer (BasicRecurrent i o) ('D1 i) ('D1 o) where
  -- Do a matrix vector multiplication and return the result.
  runRecurrentForwards (BasicRecurrent wB _ wN _) (S1D lastOutput) (S1D thisInput) =
    let thisOutput = S1D $ wB + wN #> (thisInput # lastOutput)
    in  (thisOutput, thisOutput)
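
  -- In equation form, with W the activation matrix and b the bias (an added
  -- note, not in the original source):
  --
  --   h_t = b + W [ x_t ; h_{t-1} ]
  --
  -- The same vector serves as both the recurrent state and the layer output.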

  -- Run a backpropagation step for a fully connected recurrent layer.
  runRecurrentBackwards (BasicRecurrent _ _ wN _) (S1D lastOutput) (S1D thisInput) (S1D dRec) (S1D dEdy) =
    let biasGradient = (dRec + dEdy)
        layerGrad    = (dRec + dEdy) `outer` (thisInput # lastOutput)
        -- calculate derivatives for next step
        (backGrad, recGrad) = split $ tr wN #> (dRec + dEdy)
    in  (BasicRecurrent' biasGradient layerGrad, S1D recGrad, S1D backGrad)

randomBasicRecurrent :: (MonadRandom m, KnownNat i, KnownNat o, KnownNat x, x ~ (i + o))
                     => m (BasicRecurrent i o)
randomBasicRecurrent = do
  seed1 <- getRandom
  seed2 <- getRandom
  let wB = randomVector seed1 Uniform * 2 - 1
      wN = uniformSample seed2 (-1) 1
      bm = konst 0
      mm = konst 0
  return $ BasicRecurrent wB bm wN mm
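
-- A usage sketch (hypothetical sizes, not in the original source):
--
-- > randomLayer :: MonadRandom m => m (BasicRecurrent 3 2)
-- > randomLayer = randomBasicRecurrent
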
244
src/Grenade/Recurrent/Layers/LSTM.hs
Normal file
@ -0,0 +1,244 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ViewPatterns #-}
module Grenade.Recurrent.Layers.LSTM (
    LSTM (..)
  , LSTMWeights (..)
  , randomLSTM
  ) where

import Control.Monad.Random ( MonadRandom, getRandom )

-- import Data.List ( foldl1' )
import Data.Singletons.TypeLits

import Numeric.LinearAlgebra.Static

import Grenade.Core.Network
import Grenade.Core.Shape

import Grenade.Layers.Internal.Update

import Grenade.Recurrent.Core.Network

-- | Long Short Term Memory Recurrent unit
--
--   This is a Peephole formulation, so the recurrent shape is
--   just the cell state; the previous output is not held or used
--   at all.
data LSTM :: Nat -> Nat -> * where
  LSTM :: ( KnownNat input
          , KnownNat output
          ) => !(LSTMWeights input output) -- Weights
            -> !(LSTMWeights input output) -- Momentums
            -> LSTM input output

data LSTMWeights :: Nat -> Nat -> * where
  LSTMWeights :: ( KnownNat input
                 , KnownNat output
                 ) => {
                   lstmWf :: !(L output input)  -- Weight Forget     (W_f)
                 , lstmUf :: !(L output output) -- Cell State Forget (U_f)
                 , lstmBf :: !(R output)        -- Bias Forget       (b_f)
                 , lstmWi :: !(L output input)  -- Weight Input      (W_i)
                 , lstmUi :: !(L output output) -- Cell State Input  (U_i)
                 , lstmBi :: !(R output)        -- Bias Input        (b_i)
                 , lstmWo :: !(L output input)  -- Weight Output     (W_o)
                 , lstmUo :: !(L output output) -- Cell State Output (U_o)
                 , lstmBo :: !(R output)        -- Bias Output       (b_o)
                 , lstmWc :: !(L output input)  -- Weight Cell       (W_c)
                 , lstmBc :: !(R output)        -- Bias Cell         (b_c)
                 } -> LSTMWeights input output

instance Show (LSTM i o) where
  show LSTM {} = "LSTM"

instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
  -- The gradients are the same shape as the weights and momentum.
  -- This seems to be a general pattern; maybe it should be enforced.
  type Gradient (LSTM i o) = (LSTMWeights i o)

  -- Run the update function for each group matrix/vector of weights, momentums and gradients.
  -- Hmm, maybe the function should be used instead of passing in the learning parameters.
  runUpdate LearningParameters {..} (LSTM w m) g =
    let (wf, wf') = u lstmWf w m g
        (uf, uf') = u lstmUf w m g
        (bf, bf') = v lstmBf w m g
        (wi, wi') = u lstmWi w m g
        (ui, ui') = u lstmUi w m g
        (bi, bi') = v lstmBi w m g
        (wo, wo') = u lstmWo w m g
        (uo, uo') = u lstmUo w m g
        (bo, bo') = v lstmBo w m g
        (wc, wc') = u lstmWc w m g
        (bc, bc') = v lstmBc w m g
    in LSTM (LSTMWeights wf uf bf wi ui bi wo uo bo wc bc) (LSTMWeights wf' uf' bf' wi' ui' bi' wo' uo' bo' wc' bc')
      where
    -- Utility function for updating with the momentum, gradients, and weights.
    u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
    u e (e -> weights) (e -> momentum) (e -> gradient) =
      decendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum

    v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
    v e (e -> weights) (e -> momentum) (e -> gradient) =
      decendVector learningRate learningMomentum learningRegulariser weights gradient momentum

  -- There's a lot of updates here, so to try and minimise the number of data copies
  -- we'll create a mutable bucket for each.
  -- runUpdates rate lstm gs =
  --   let combinedGradient = foldl1' uu gs
  --   in  runUpdate rate lstm combinedGradient
  --     where
  --   uu :: (KnownNat i, KnownNat o) => LSTMWeights i o -> LSTMWeights i o -> LSTMWeights i o
  --   uu a b =
  --     let wf = u lstmWf a b
  --         uf = u lstmUf a b
  --         bf = v lstmBf a b
  --         wi = u lstmWi a b
  --         ui = u lstmUi a b
  --         bi = v lstmBi a b
  --         wo = u lstmWo a b
  --         uo = u lstmUo a b
  --         bo = v lstmBo a b
  --         wc = u lstmWc a b
  --         bc = v lstmBc a b
  --     in LSTMWeights wf uf bf wi ui bi wo uo bo wc bc
  --   u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> L out ix
  --   u e (e -> a) (e -> b) = tr $ tr a + tr b

  --   v :: forall x ix. (x -> (R ix)) -> x -> x -> R ix
  --   v e (e -> a) (e -> b) = a + b

  createRandom = randomLSTM

instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
  -- The recurrent shape is the same size as the output.
  -- It's actually the cell state however, as this is a peephole variety LSTM.
  type RecurrentShape (LSTM i o) = 'D1 o

instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
  -- Forward propagation for the LSTM layer.
  -- The size of the cell state is also the size of the output.
  runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
    let -- Forget state vector
        f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
        -- Input state vector
        i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
        -- Output state vector
        o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
        -- Cell input state vector
        c_x = tanh $ lstmBc + lstmWc #> input
        -- Cell state
        c_t = f_t * cell + i_t * c_x
        -- Output (it's sometimes recommended to use tanh c_t)
        h_t = o_t * c_t
    in (S1D c_t, S1D h_t)
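
  -- In equations (an added summary of the code above; peephole variant,
  -- no previous-output feedback):
  --
  --   f_t = sigmoid (b_f + W_f x_t + U_f c_{t-1})
  --   i_t = sigmoid (b_i + W_i x_t + U_i c_{t-1})
  --   o_t = sigmoid (b_o + W_o x_t + U_o c_{t-1})
  --   c_t = f_t * c_{t-1} + i_t * tanh (b_c + W_c x_t)
  --   h_t = o_t * c_t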

  -- Run a backpropagation step for an LSTM layer.
  -- We're doing all the derivatives by hand here, so one should
  -- be extra careful when changing this.
  --
  -- There's a test version using the AD library without hmatrix in the test
  -- suite. These should always match.
  runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) (S1D cellGrad) (S1D h_t') =
    -- We're not keeping the Wengert tape during the forward pass,
    -- so we're duplicating some work here.
    --
    -- If I was being generous, I'd call it checkpointing.
    --
    -- Maybe think about better ways to store some intermediate states.
    let -- Forget state vector
        f_s = lstmBf + lstmWf #> input + lstmUf #> cell
        f_t = sigmoid f_s
        -- Input state vector
        i_s = lstmBi + lstmWi #> input + lstmUi #> cell
        i_t = sigmoid i_s
        -- Output state vector
        o_s = lstmBo + lstmWo #> input + lstmUo #> cell
        o_t = sigmoid o_s
        -- Cell input state vector
        c_s = lstmBc + lstmWc #> input
        c_x = tanh c_s
        -- Cell state
        c_t = f_t * cell + i_t * c_x

        -- Reverse Mode AD Derivatives
        c_t' = h_t' * o_t + cellGrad

        f_t' = c_t' * cell
        f_s' = sigmoid' f_s * f_t'

        o_t' = h_t' * c_t
        o_s' = sigmoid' o_s * o_t'

        i_t' = c_t' * c_x
        i_s' = sigmoid' i_s * i_t'

        c_x' = c_t' * i_t
        c_s' = tanh' c_s * c_x'

        -- The derivatives to pass sideways (recurrent) and downwards
        cell'  = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
        input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'

        -- Calculate the gradient matrices for the input
        lstmWf' = f_s' `outer` input
        lstmWi' = i_s' `outer` input
        lstmWo' = o_s' `outer` input
        lstmWc' = c_s' `outer` input

        -- Calculate the gradient matrices for the cell
        lstmUf' = f_s' `outer` cell
        lstmUi' = i_s' `outer` cell
        lstmUo' = o_s' `outer` cell

        -- The biases just get the values, but we'll write it so it's obvious
        lstmBf' = f_s'
        lstmBi' = i_s'
        lstmBo' = o_s'
        lstmBc' = c_s'

        gradients = LSTMWeights lstmWf' lstmUf' lstmBf' lstmWi' lstmUi' lstmBi' lstmWo' lstmUo' lstmBo' lstmWc' lstmBc'
    in (gradients, S1D cell', S1D input')

-- | Generate an LSTM layer with random weights
--   (one can also just call createRandom from UpdateLayer).
--
--   Has forget gate biases set to 1 to encourage early learning.
--
--   https://github.com/karpathy/char-rnn/commit/0dfeaa454e687dd0278f036552ea1e48a0a408c9
--
randomLSTM :: forall m i o. (MonadRandom m, KnownNat i, KnownNat o)
           => m (LSTM i o)
randomLSTM = do
  let w = (\s -> uniformSample s (-1) 1 ) <$> getRandom
      u = (\s -> uniformSample s (-1) 1 ) <$> getRandom
      v = (\s -> randomVector s Uniform * 2 - 1) <$> getRandom

      w0 = konst 0
      u0 = konst 0
      v0 = konst 0

  LSTM <$> (LSTMWeights <$> w <*> u <*> pure (konst 1) <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v)
       <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)

-- | Maths
--
--   TODO: move to not here
sigmoid :: Floating a => a -> a
sigmoid x = 1 / (1 + exp (-x))

sigmoid' :: Floating a => a -> a
sigmoid' x = logix * (1 - logix)
  where
    logix = sigmoid x

tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
23
src/Grenade/Recurrent/Layers/Trivial.hs
Normal file
@ -0,0 +1,23 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
module Grenade.Recurrent.Layers.Trivial (
    Trivial (..)
  ) where

import Grenade.Core.Network

-- | A trivial layer.
data Trivial = Trivial
  deriving Show

instance UpdateLayer Trivial where
  type Gradient Trivial = ()
  runUpdate _ _ _ = Trivial
  createRandom = return Trivial

instance (a ~ b) => Layer Trivial a b where
  runForwards _ = id
  runBackwards _ _ y = ((), y)
84
src/Grenade/Utils/OneHot.hs
Normal file
@ -0,0 +1,84 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}

module Grenade.Utils.OneHot (
    oneHot
  , hotMap
  , makeHot
  , unHot
  ) where

import Data.List ( group, sort )

import Data.Map ( Map )
import qualified Data.Map as M

import Data.Proxy
import Data.Singletons.TypeLits

import Data.Vector ( Vector )
import qualified Data.Vector as V

import Numeric.LinearAlgebra ( maxIndex )
import Numeric.LinearAlgebra.Devel
import Numeric.LinearAlgebra.Static

import Grenade.Core.Shape

-- | From an int which is hot, create a 1D Shape
--   with one index hot (1) and the rest 0.
--   Returns Nothing if the hot number is larger
--   than the length of the vector.
oneHot :: forall n. (KnownNat n)
       => Int -> Maybe (S ('D1 n))
oneHot hot =
  let len = fromIntegral $ natVal (Proxy :: Proxy n)
  in  if hot < len
        then
          fmap S1D . create $ runSTVector $ do
            vec <- newVector 0 len
            writeVector vec hot 1
            return vec
        else Nothing
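
-- For example (an added illustration):
--
-- > oneHot 2 :: Maybe (S ('D1 5))   -- one-hot vector [0, 0, 1, 0, 0]
-- > oneHot 7 :: Maybe (S ('D1 5))   -- Nothing, index out of range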

-- | Create a one hot map from any enumerable.
--   Returns a map, and the ordered list for the reverse transformation
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
hotMap n as =
  let len  = fromIntegral $ natVal n
      uniq = [ c | (c:_) <- group $ sort as]
      hotl = length uniq
  in  if hotl <= len
        then
          Just (M.fromList $ zip uniq [0..], V.fromList uniq)
        else Nothing

-- | From a map and value, create a 1D Shape
--   with one index hot (1) and the rest 0.
--   Returns Nothing if the hot number is larger
--   than the length of the vector or the map
--   doesn't contain the value.
makeHot :: forall a n. (Ord a, KnownNat n)
        => Map a Int -> a -> Maybe (S ('D1 n))
makeHot m x = do
  hot <- M.lookup x m
  let len = fromIntegral $ natVal (Proxy :: Proxy n)
  if hot < len
    then
      fmap S1D . create $ runSTVector $ do
        vec <- newVector 0 len
        writeVector vec hot 1
        return vec
    else Nothing

unHot :: forall a n. (KnownNat n)
      => Vector a -> (S ('D1 n)) -> Maybe a
unHot v (S1D xs)
  = (V.!?) v
  $ maxIndex (extract xs)
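
-- A round-trip sketch (added illustration, hypothetical values):
--
-- > Just (m, v) = hotMap (Proxy :: Proxy 26) ['a' .. 'z']
-- > Just s      = makeHot m 'q' :: Maybe (S ('D1 26))
-- > unHot v s   -- Just 'q'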

@ -1,11 +1,10 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Convolution where

@ -30,6 +29,17 @@ data OpaqueConvolution :: * where
instance Show OpaqueConvolution where
  show (OpaqueConvolution n) = show n

genConvolution :: ( KnownNat channels
                  , KnownNat filters
                  , KnownNat kernelRows
                  , KnownNat kernelColumns
                  , KnownNat strideRows
                  , KnownNat strideColumns
                  , KnownNat kernelFlattened
                  , kernelFlattened ~ (kernelRows * kernelColumns * channels)
                  ) => Jack (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
genConvolution = Convolution <$> uniformSample <*> uniformSample

genOpaqueOpaqueConvolution :: Jack OpaqueConvolution
genOpaqueOpaqueConvolution = do
  Just channels <- someNatVal <$> choose (1, 10)
@ -46,7 +56,7 @@ genOpaqueOpaqueConvolution = do
      p2 = natDict pkc
      p3 = natDict pch
  in case p1 %* p2 %* p3 of
    Dict -> OpaqueConvolution <$> (Convolution <$> uniformSample <*> uniformSample :: Jack (Convolution ch fl kr kc sr sc))
    Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc))

prop_conv_net_witness =
  gamble genOpaqueOpaqueConvolution $ \onet ->
@ -80,9 +90,9 @@ prop_conv_net =
  , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
  , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
    (Dict, Dict, Dict, Dict) ->
      gamble (S3D' <$> uniformSample) $ \(input :: S' ('D3 inRows inCols channels)) ->
      gamble (S3D <$> uniformSample) $ \(input :: S ('D3 inRows inCols channels)) ->
        let output :: S' ('D3 outRows outCols filters) = runForwards convLayer input
        let output :: S ('D3 outRows outCols filters) = runForwards convLayer input
            backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S' ('D3 inRows inCols channels))
            backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S ('D3 inRows inCols channels))
              = runBackwards convLayer input output
        in backed `seq` True
  ) :: Property
@ -44,10 +44,10 @@ genOpaqueFullyConnected = do
prop_fully_connected_forwards :: Property
prop_fully_connected_forwards =
  gamble genOpaqueFullyConnected $ \(OpaqueFullyConnected (fclayer :: FullyConnected i o)) ->
    gamble (S1D' <$> randomVector) $ \(input :: S' ('D1 i)) ->
    gamble (S1D <$> randomVector) $ \(input :: S ('D1 i)) ->
      let output :: S' ('D1 o) = runForwards fclayer input
      let output :: S ('D1 o) = runForwards fclayer input
          backed :: (Gradient (FullyConnected i o), S' ('D1 i))
          backed :: (Gradient (FullyConnected i o), S ('D1 i))
            = runBackwards fclayer input output
      in backed `seq` True

return []

@ -1,8 +1,8 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Pooling where

101
test/Test/Grenade/Recurrent/Layers/LSTM.hs
Normal file
@ -0,0 +1,101 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Recurrent.Layers.LSTM where

import Disorder.Jack

import Data.Foldable ( toList )
import Data.Singletons.TypeLits

import Grenade
import Grenade.Recurrent

import qualified Numeric.LinearAlgebra as H
import qualified Numeric.LinearAlgebra.Static as S

import qualified Test.Grenade.Recurrent.Layers.LSTM.Reference as Reference
import Test.Jack.Hmatrix

genLSTM :: forall i o. (KnownNat i, KnownNat o) => Jack (LSTM i o)
genLSTM = do
  let w = uniformSample
      u = uniformSample
      v = randomVector

      w0 = S.konst 0
      u0 = S.konst 0
      v0 = S.konst 0

  LSTM <$> (LSTMWeights <$> w <*> u <*> v <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v)
       <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)

prop_lstm_reference_forwards =
  gamble randomVector $ \(input :: S.R 3) ->
    gamble randomVector $ \(cell :: S.R 2) ->
      gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
        let actual = runRecurrentForwards net (S1D cell) (S1D input)
        in case actual of
          ((S1D cellOut) :: S ('D1 2), (S1D output) :: S ('D1 2)) ->
            let cellOut' = Reference.Vector . H.toList . S.extract $ cellOut
                output'  = Reference.Vector . H.toList . S.extract $ output
                refNet   = Reference.lstmToReference lstmWeights
                refCell  = Reference.Vector . H.toList . S.extract $ cell
                refInput = Reference.Vector . H.toList . S.extract $ input
                (refCO, refO) = Reference.runLSTM refNet refCell refInput
            in toList refCO ~~~ toList cellOut' .&&. toList refO ~~~ toList output'

prop_lstm_reference_backwards =
  gamble randomVector $ \(input :: S.R 3) ->
    gamble randomVector $ \(cell :: S.R 2) ->
      gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
        let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
        in case actualBacks of
          (actualGradients, _, _) ->
            let refNet   = Reference.lstmToReference lstmWeights
                refCell  = Reference.Vector . H.toList . S.extract $ cell
                refInput = Reference.Vector . H.toList . S.extract $ input
                refGradients = Reference.runLSTMback refCell refInput refNet
            in toList refGradients ~~~ toList (Reference.lstmToReference actualGradients)

prop_lstm_reference_backwards_input =
  gamble randomVector $ \(input :: S.R 3) ->
    gamble randomVector $ \(cell :: S.R 2) ->
      gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
        let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
        in case actualBacks of
          (_, _, S1D actualGradients) ->
            let refNet   = Reference.lstmToReference lstmWeights
                refCell  = Reference.Vector . H.toList . S.extract $ cell
                refInput = Reference.Vector . H.toList . S.extract $ input
                refGradients = Reference.runLSTMbackOnInput refCell refNet refInput
            in toList refGradients ~~~ H.toList (S.extract actualGradients)

prop_lstm_reference_backwards_cell =
  gamble randomVector $ \(input :: S.R 3) ->
    gamble randomVector $ \(cell :: S.R 2) ->
      gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
        let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
        in case actualBacks of
          (_, S1D actualGradients, _) ->
            let refNet   = Reference.lstmToReference lstmWeights
                refCell  = Reference.Vector . H.toList . S.extract $ cell
                refInput = Reference.Vector . H.toList . S.extract $ input
                refGradients = Reference.runLSTMbackOnCell refInput refNet refCell
            in toList refGradients ~~~ (H.toList . S.extract $ actualGradients)

-- Approximate equality on lists: every pairwise absolute difference is tiny.
(~~~) as bs = all ((< 1e-8) . abs) (zipWith (-) as bs)
infix 4 ~~~

return []
tests :: IO Bool
tests = $quickCheckAll
149
test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs
Normal file
@ -0,0 +1,149 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Recurrent.Layers.LSTM.Reference where

import Data.Reflection
import Numeric.AD.Mode.Reverse
import Numeric.AD.Internal.Reverse ( Tape )

import qualified Grenade.Recurrent.Layers.LSTM as LSTM
import qualified Numeric.LinearAlgebra.Static as S
import qualified Numeric.LinearAlgebra as H

--
-- This module contains a set of list only versions of
-- an LSTM layer which can be used with the AD library.
--
-- Using this, we can check to make sure that our fast
-- back propagation implementation is correct.
--

-- | List only matrix deriving functor
data Matrix a = Matrix {
    matrixWeights :: [[a]]
  } deriving (Functor, Foldable, Traversable, Eq, Show)

-- | List only vector deriving functor
data Vector a = Vector {
    vectorWeights :: [a]
  } deriving (Functor, Foldable, Traversable, Eq, Show)

-- | List only LSTM weights
data RefLSTM a = RefLSTM
  { refLstmWf :: Matrix a -- Weight Forget     (W_f)
  , refLstmUf :: Matrix a -- Cell State Forget (U_f)
  , refLstmBf :: Vector a -- Bias Forget       (b_f)
  , refLstmWi :: Matrix a -- Weight Input      (W_i)
  , refLstmUi :: Matrix a -- Cell State Input  (U_i)
  , refLstmBi :: Vector a -- Bias Input        (b_i)
  , refLstmWo :: Matrix a -- Weight Output     (W_o)
  , refLstmUo :: Matrix a -- Cell State Output (U_o)
  , refLstmBo :: Vector a -- Bias Output       (b_o)
  , refLstmWc :: Matrix a -- Weight Cell       (W_c)
  , refLstmBc :: Vector a -- Bias Cell         (b_c)
  } deriving (Functor, Foldable, Traversable, Eq, Show)

lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double
lstmToReference LSTM.LSTMWeights {..} =
  let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget     (W_f)
      refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f)
      refLstmBf = Vector . H.toList  . S.extract $ lstmBf -- Bias Forget       (b_f)
      refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input      (W_i)
      refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input  (U_i)
      refLstmBi = Vector . H.toList  . S.extract $ lstmBi -- Bias Input        (b_i)
      refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output     (W_o)
      refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o)
      refLstmBo = Vector . H.toList  . S.extract $ lstmBo -- Bias Output       (b_o)
      refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell       (W_c)
      refLstmBc = Vector . H.toList  . S.extract $ lstmBc -- Bias Cell         (b_c)
  in RefLSTM {..}

runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a)
runLSTM RefLSTM {..} cell input =
  let -- Forget state vector
      f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell
      -- Input state vector
      i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell
      -- Output state vector
      o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell
      -- Cell input state vector
      c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input
      -- Cell state
      c_t = f_t #* cell #+ i_t #* c_x
      -- Output (it's sometimes recommended to use tanh c_t)
      h_t = o_t #* c_t
  in (c_t, h_t)

runLSTMback :: forall a. Floating a => Vector a -> Vector a -> RefLSTM a -> RefLSTM a
runLSTMback cell input =
  grad f
    where
  f :: forall s. Reifies s Tape => RefLSTM (Reverse s a) -> Reverse s a
  f net =
    let cell'  = fmap auto cell
        input' = fmap auto input
        (cells, forwarded) = runLSTM net cell' input'
    in sum forwarded + sum cells

runLSTMbackOnInput :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a
runLSTMbackOnInput cell net =
  grad f
    where
  f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a
  f input =
    let cell' = fmap auto cell
        net'  = fmap auto net
        (cells, forwarded) = runLSTM net' cell' input
    in sum forwarded + sum cells

runLSTMbackOnCell :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a
runLSTMbackOnCell input net =
  grad f
    where
  f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a
  f cell =
    let input' = fmap auto input
        net'   = fmap auto net
        (cells, forwarded) = runLSTM net' cell input'
    in sum forwarded + sum cells

-- | Helper to multiply a matrix by a vector
matMult :: Num a => Matrix a -> Vector a -> Vector a
matMult (Matrix m) (Vector v) = Vector result
  where
    lrs = map length m
    l   = length v
    result = if all (== l) lrs
               then map (\r -> sum $ zipWith (*) r v) m
               else error $ "Matrix has rows of length " ++ show lrs ++
                            " but vector is of length " ++ show l

(#>) :: Num a => Matrix a -> Vector a -> Vector a
(#>) = matMult
infixr 8 #>

(#+) :: Num a => Vector a -> Vector a -> Vector a
(#+) (Vector as) (Vector bs) = Vector $ zipWith (+) as bs
infixl 6 #+

(#-) :: Num a => Vector a -> Vector a -> Vector a
(#-) (Vector as) (Vector bs) = Vector $ zipWith (-) as bs
infixl 6 #-

(#*) :: Num a => Vector a -> Vector a -> Vector a
(#*) (Vector as) (Vector bs) = Vector $ zipWith (*) as bs
infixl 7 #*

sigmoid :: (Functor f, Floating a) => f a -> f a
sigmoid xs = (\x -> 1 / (1 + exp (-x))) <$> xs
@ -4,7 +4,6 @@
module Test.Jack.Hmatrix where

import Data.Proxy
import Disorder.Jack

import GHC.TypeLits
@ -12,9 +11,7 @@ import GHC.TypeLits
import qualified Numeric.LinearAlgebra.Static as HStatic

randomVector :: forall n. KnownNat n => Jack (HStatic.R n)
randomVector = HStatic.fromList <$> vectorOf (fromInteger (natVal (Proxy :: Proxy n))) sizedRealFrac
randomVector = (\s -> HStatic.randomVector s HStatic.Uniform * 2 - 1) <$> sizedNat

uniformSample :: forall m n. (KnownNat m, KnownNat n) => Jack (HStatic.L m n)
uniformSample = HStatic.fromList
uniformSample = (\s -> HStatic.uniformSample s (-1) 1 ) <$> sizedNat
  <$> vectorOf (fromInteger (natVal (Proxy :: Proxy m)) * fromInteger (natVal (Proxy :: Proxy n)))
      sizedRealFrac
13
test/test.hs
@ -1,12 +1,13 @@
import Disorder.Core.Main

import qualified Test.Grenade.Layers.Pooling as Test.Grenade.Layers.Pooling
import qualified Test.Grenade.Layers.Pooling
import qualified Test.Grenade.Layers.Convolution as Test.Grenade.Layers.Convolution
import qualified Test.Grenade.Layers.Convolution
import qualified Test.Grenade.Layers.FullyConnected as Test.Grenade.Layers.FullyConnected
import qualified Test.Grenade.Layers.FullyConnected

import qualified Test.Grenade.Layers.Internal.Convolution as Test.Grenade.Layers.Internal.Convolution
import qualified Test.Grenade.Layers.Internal.Convolution
import qualified Test.Grenade.Layers.Internal.Pooling as Test.Grenade.Layers.Internal.Pooling
import qualified Test.Grenade.Layers.Internal.Pooling

import qualified Test.Grenade.Recurrent.Layers.LSTM

main :: IO ()
main =
@ -17,4 +18,6 @@ main =
  , Test.Grenade.Layers.Internal.Convolution.tests
  , Test.Grenade.Layers.Internal.Pooling.tests

  , Test.Grenade.Recurrent.Layers.LSTM.tests
  ]
Loading…
Reference in New Issue
Block a user