Mirror of https://github.com/HuwCampbell/grenade.git
Synced 2024-07-14 13:10:23 +03:00
Remove primes on shape instantiations

Add singletons for Shape and remove hacks on recurrent nets. Add recurrent nets.
parent 88021fbde7
commit ac0e4b22c8
@@ -52,7 +52,7 @@ Usage

To perform back propagation, one can call the eponymous function

```haskell
backPropagate :: forall input target shapes layers. (Head shapes ~ input, Last shapes ~ target)
              => Network layers shapes -> S' input -> S' target -> Gradients layers
              => Network layers shapes -> S input -> S target -> Gradients layers
```

which takes a network, appropriate input and target data, and returns the
back propagated gradients for the network. The shapes of the gradients are
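For orientation, one whole training step built from these pieces might look like the sketch below. It simply mirrors the `train` function defined in `Grenade.Core.Runner` later in this diff; `step` and `lp` are illustrative names, not part of the library.

```haskell
-- A minimal sketch, assuming the renamed shape type S shown above.
step :: LearningParameters                    -- rate, momentum, regulariser
     -> Network layers shapes                 -- current network
     -> S (Head shapes)                       -- input datum
     -> S (Last shapes)                       -- target datum
     -> Network layers shapes
step lp net input target =
  let grads = backPropagate net input target  -- gradients for every layer
  in  applyUpdate lp net grads                -- one stochastic gradient step
```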
cbits/gradient_decent.c (new file, 13 lines)
@@ -0,0 +1,13 @@
#include "gradient_decent.h"

void decend_cpu(int len, double rate, double momentum, double regulariser,
                const double* weights,
                const double* gradient,
                const double* last,
                double* outputWeights, double* outputMomentum) {

  for (int i = 0; i < len; i++) {
    outputMomentum[i] = momentum * last[i] - rate * gradient[i];
    outputWeights[i] = weights[i] + outputMomentum[i] - (rate * regulariser) * weights[i];
  }
}
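Elementwise this is the classic momentum update: first the velocity v' = momentum * v - rate * g, then w' = w + v' - rate * regulariser * w. On the Haskell side a binding for it could be declared roughly as below; the actual binding lives in the new `Grenade.Layers.Internal.Update` module, which this diff references but does not show, so treat the exact form as an assumption.

```haskell
import Foreign.C.Types ( CInt (..), CDouble (..) )
import Foreign.Ptr     ( Ptr )

-- Hypothetical FFI sketch for the C routine above (names assumed).
foreign import ccall unsafe "decend_cpu"
  decend_cpu :: CInt -> CDouble -> CDouble -> CDouble  -- len, rate, momentum, regulariser
             -> Ptr CDouble                            -- weights
             -> Ptr CDouble                            -- gradient
             -> Ptr CDouble                            -- last momentum
             -> Ptr CDouble -> Ptr CDouble             -- output weights, output momentum
             -> IO ()
```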
cbits/gradient_decent.h (new file, 9 lines)
@@ -0,0 +1,9 @@
#include <stdio.h>
#include <stdint.h>

void decend_cpu(int len, double rate, double momentum, double regulariser,
                const double* weights,
                const double* gradient,
                const double* last,
                double* outputWeights, double* outputMomentum);
cbits/im2col.c

@@ -1,11 +1,10 @@
#include "im2col.h"

void im2col_cpu(const double* data_im, int dataOffset, const int channels,
void im2col_cpu(const double* data_im, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w,
    double* data_col) {

  data_im += dataOffset;
  const int channel_size = height * width;

  for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
@@ -23,13 +22,12 @@ void im2col_cpu(const double* data_im, int dataOffset, const int channels,
  }
}

void col2im_cpu(const double* data_col, int dataOffset, const int channels,
void col2im_cpu(const double* data_col, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w,
    double* data_im) {

  memset(data_im, 0, height * width * channels * sizeof(double));
  data_col += dataOffset;

  const int channel_size = height * width;

@@ -50,13 +48,11 @@ void col2im_cpu(const double* data_col, int dataOffset, const int channels,

inline double max ( double a, double b ) { return a > b ? a : b; }

void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
void pool_forwards_cpu(const double* data_im, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w,
    double* data_pooled) {

  data_im += dataOffset;

  const int channel_size = height * width;

  for (int channel = 0; channel < channels; channel++) {
@@ -89,14 +85,11 @@ void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels
  }
}

void pool_backwards_cpu(const double* data_im, int data_im_offset,
    const double* data_pooled, int data_pooled_offset,
void pool_backwards_cpu(const double* data_im, const double* data_pooled,
    const int channels, const int height, const int width, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w,
    double* data_backgrad ) {

  data_im += data_im_offset;
  data_pooled += data_pooled_offset;
  memset(data_backgrad, 0, height * width * channels * sizeof(double));

  const int channel_size = height * width;
cbits/im2col.h

@@ -2,23 +2,22 @@
#include <stdint.h>
#include <string.h>

void im2col_cpu(const double* data_im, int dataOffset, const int channels,
void im2col_cpu(const double* data_im, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w,
    double* data_col);

void col2im_cpu(const double* data_col, int dataOffset, const int channels,
void col2im_cpu(const double* data_col, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w,
    double* data_im);

void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
void pool_forwards_cpu(const double* data_im, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w,
    double* data_pooled);

void pool_backwards_cpu(const double* data_im, int data_im_offset,
    const double* data_pooled, int data_pooled_offset,
void pool_backwards_cpu(const double* data_im, const double* data_pooled,
    const int channels, const int height, const int width, const int kernel_h,
    const int kernel_w, const int stride_h, const int stride_w,
    double* data_backgrad );
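These routines implement the im2col trick referenced later in `Grenade.Layers.Convolution`: each kernel-sized window of the image is copied out into one row of a matrix, so convolution becomes a single matrix multiply against the flattened kernels. A toy, pure-Haskell rendering of the idea for a single channel follows (illustrative only; the real code above works over flat C arrays and a positive stride is assumed):

```haskell
-- For a kh x kw kernel with strides sh and sw, every window becomes a row.
im2colRef :: Int -> Int -> Int -> Int -> [[Double]] -> [[Double]]
im2colRef kh kw sh sw img =
  [ concat [ take kw (drop c row) | row <- take kh (drop r img) ]
  | r <- [0, sh .. length img - kh]          -- window top edges
  , c <- [0, sw .. length (head img) - kw]   -- window left edges
  ]
```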
grenade.cabal

@@ -19,6 +19,8 @@ library
                      base >= 4.8 && < 5
                    , bytestring == 0.10.*
                    , async
                    , containers
                    , deepseq
                    , either == 4.4.*
                    , exceptions == 0.8.*
                    , hmatrix
@@ -26,9 +28,11 @@ library
                    , mtl >= 2.2.1 && < 2.3
                    , parallel == 3.2.*
                    , primitive
                    , reflection
                    , text == 1.2.*
                    , transformers
                    , singletons
                    , vector

  ghc-options:
                    -Wall
@@ -55,9 +59,23 @@ library

                    Grenade.Layers.Internal.Convolution
                    Grenade.Layers.Internal.Pooling
                    Grenade.Layers.Internal.Update

                    Grenade.Recurrent

                    Grenade.Recurrent.Core.Network
                    Grenade.Recurrent.Core.Runner

                    Grenade.Recurrent.Layers.BasicRecurrent
                    Grenade.Recurrent.Layers.LSTM
                    Grenade.Recurrent.Layers.Trivial

                    Grenade.Utils.OneHot

  includes:         cbits/im2col.h
                    cbits/gradient_decent.h
  c-sources:        cbits/im2col.c
                    cbits/gradient_decent.c

  cc-options:       -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1

@@ -90,6 +108,40 @@ executable mnist
                    , transformers
                    , singletons
                    , MonadRandom
                    , vector

executable recurrent
  ghc-options:      -Wall -threaded -O2
  main-is:          main/recurrent.hs
  build-depends:      base
                    , grenade
                    , attoparsec
                    , either
                    , optparse-applicative == 0.12.*
                    , text == 1.2.*
                    , mtl >= 2.2.1 && < 2.3
                    , hmatrix >= 0.18 && < 0.19
                    , transformers
                    , singletons
                    , MonadRandom


executable shakespeare
  ghc-options:      -Wall -threaded -O2
  main-is:          main/shakespeare.hs
  build-depends:      base
                    , grenade
                    , attoparsec
                    , either
                    , optparse-applicative == 0.12.*
                    , text == 1.2.*
                    , mtl >= 2.2.1 && < 2.3
                    , hmatrix >= 0.18 && < 0.19
                    , transformers
                    , singletons
                    , vector
                    , MonadRandom
                    , containers


test-suite test
@@ -117,6 +169,8 @@ test-suite test
                    , quickcheck-instances == 0.3.*
                    , MonadRandom
                    , random
                    , ad
                    , reflection


benchmark bench
@@ -135,3 +189,20 @@ benchmark bench
                    , criterion == 1.1.*
                    , grenade
                    , hmatrix

benchmark bench-lstm
  type:             exitcode-stdio-1.0

  main-is:          bench-lstm.hs

  ghc-options:      -Wall -threaded -O2

  hs-source-dirs:
                    bench

  build-depends:
                      base >= 3 && < 5
                    , bytestring == 0.10.*
                    , criterion == 1.1.*
                    , grenade
                    , hmatrix
mafia (6 lines changed)

@@ -1,5 +1,7 @@
#!/bin/sh -eu

: ${MAFIA_HOME:=$HOME/.mafia}

fetch_latest () {
  if [ -z ${MAFIA_TEST_MODE+x} ]; then
    TZ=$(date +"%T")
@@ -55,7 +57,7 @@ exec_mafia () {
    # If we can't find the mafia version, then we need to upgrade the script.
    run_upgrade
  else
    MAFIA_BIN=$HOME/.ambiata/mafia/bin
    MAFIA_BIN=$MAFIA_HOME/bin
    MAFIA_FILE=mafia-$MAFIA_VERSION
    MAFIA_PATH=$MAFIA_BIN/$MAFIA_FILE

@@ -118,4 +120,4 @@ case "$MODE" in
  upgrade) shift; run_upgrade "$@" ;;
  *) exec_mafia "$@"
esac
# Version: a1b39ee8ac1969ed2e891b9062d079be75863e99
# Version: 3044e63eb472fb9e16926d4ab2ca9dd9e255829c
main/feedforward.hs

@@ -4,10 +4,10 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}

import Control.Monad
import Control.Monad.Random
import Data.List ( foldl' )

import GHC.TypeLits

import qualified Numeric.LinearAlgebra.Static as SA
@@ -34,18 +34,18 @@ netTest :: MonadRandom m => LearningParameters -> Int -> m String
netTest rate n = do
    inps <- replicateM n $ do
      s <- getRandom
      return $ S1D' $ SA.randomVector s SA.Uniform * 2 - 1
    let outs = flip map inps $ \(S1D' v) ->
      return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
    let outs = flip map inps $ \(S1D v) ->
          if v `inCircle` (fromRational 0.33, 0.33) || v `inCircle` (fromRational (-0.33), 0.33)
            then S1D' $ fromRational 1
            else S1D' $ fromRational 0
            then S1D $ fromRational 1
            else S1D $ fromRational 0
    net0 <- randomNet

    let trained = foldl trainEach net0 (zip inps outs)
    let trained = foldl' trainEach net0 (zip inps outs)
    let testIns = [ [ (x,y) | x <- [0..50] ]
                  | y <- [0..20] ]

    let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D' $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
    let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
    return $ unlines outMat

  where
@@ -59,8 +59,8 @@ netTest rate n = do
      | n' <= 0.8 = '='
      | otherwise = '#'

    normx :: S' ('D1 1) -> Double
    normx (S1D' r) = SA.mean r
    normx :: S ('D1 1) -> Double
    normx (S1D r) = SA.mean r

data FeedForwardOpts = FeedForwardOpts Int LearningParameters
main/mnist.hs

@@ -5,23 +5,24 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}

import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except

import qualified Data.Attoparsec.Text as A
import Data.List ( foldl' )
import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified Data.Vector.Storable as V

import Numeric.LinearAlgebra (maxIndex)
import Numeric.LinearAlgebra ( maxIndex )
import qualified Numeric.LinearAlgebra.Static as SA

import Options.Applicative

import Grenade
import Grenade.Utils.OneHot

-- The definition of our convolutional neural network.
-- In the type signature, we have a type level list of shapes which are passed between the layers.
@@ -49,9 +50,9 @@ convTest iterations trainFile validateFile rate = do
    trainEach rate' !network (i, o) = train rate' network i o

    runIteration trainRows validateRows net i = do
      let trained' = foldl (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
      let trained' = foldl' (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
      let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
      let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
      let res' = fmap (\(S1D label, S1D prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
      print trained'
      putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
      return trained'
@@ -61,7 +62,7 @@ data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters

mnist' :: Parser MnistOpts
mnist' = MnistOpts <$> argument str (metavar "TRAIN")
                   <*> argument str (metavar "VALIDATE")
                   <*> option auto (long "iterations" <> short 'i' <> value 10)
                   <*> option auto (long "iterations" <> short 'i' <> value 15)
                   <*> (LearningParameters
                       <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
                       <*> option auto (long "momentum" <> value 0.9)
@@ -78,14 +79,14 @@ main = do
    Right () -> pure ()
    Left err -> putStrLn err

readMNIST :: FilePath -> ExceptT String IO [(S' ('D2 28 28), S' ('D1 10))]
readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
readMNIST mnist = ExceptT $ do
  mnistdata <- T.readFile mnist
  return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)

parseMNIST :: A.Parser (S' ('D2 28 28), S' ('D1 10))
parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
parseMNIST = do
  lab <- A.decimal
  pixels <- many (A.char ',' >> A.double)
  let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
  return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')
  Just lab <- oneHot <$> A.decimal
  pixels <- many (A.char ',' >> A.double)
  image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
  return (image, lab)
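The rewritten parser leans on the new `Grenade.Utils.OneHot` module, which this diff does not show, so its exact signature is an assumption inferred from the usage above:

```haskell
-- Assuming oneHot :: KnownNat n => Int -> Maybe (S ('D1 n)), as suggested
-- by the `Just lab <- oneHot <$> A.decimal` line above, the digit 3
-- becomes a one-hot vector over the ten digit classes:
label3 :: Maybe (S ('D1 10))
label3 = oneHot 3
```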
main/recurrent.hs (new file, 87 lines)
@@ -0,0 +1,87 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

import Control.Monad ( foldM )
import Control.Monad.Random ( MonadRandom, getRandomR )

import Data.List ( cycle, unfoldr )
import qualified Numeric.LinearAlgebra.Static as SA

import Options.Applicative

import Grenade
import Grenade.Recurrent

-- The definition for our simple recurrent network.
-- This file just trains a network to generate a repeating sequence
-- of 0 0 1.
--
-- The F and R types are tagging types to ensure that the runner and
-- creation function know how to treat the layers.
type F = FeedForward
type R = Recurrent

type RecNet = RecurrentNetwork '[ R (LSTM 1 4), R (LSTM 4 1), F Trivial]
                               '[ 'D1 1, 'D1 4, 'D1 1, 'D1 1 ]

type RecInput = RecurrentInputs '[ R (LSTM 1 4), R (LSTM 4 1), F Trivial]

randomNet :: MonadRandom m => m (RecNet, RecInput)
randomNet = randomRecurrent

netTest :: MonadRandom m => RecNet -> RecInput -> LearningParameters -> Int -> m (RecNet, RecInput)
netTest net0 i0 rate iterations =
    foldM trainIteration (net0,i0) [1..iterations]
  where
    trainingCycle = cycle [c 0, c 0, c 1]

    trainIteration (net, io) _ = do
      dropping <- getRandomR (0, 2)
      count <- getRandomR (5, 30)
      let t = drop dropping trainingCycle
      let example = ((,Nothing) <$> take count t) ++ [(t !! count, Just $ t !! (count + 1))]
      return $ trainEach net io example

    trainEach !nt !io !ex = trainRecurrent rate nt io ex

data FeedForwardOpts = FeedForwardOpts Int LearningParameters

feedForward' :: Parser FeedForwardOpts
feedForward' = FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 20000)
                               <*> (LearningParameters
                                   <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
                                   <*> option auto (long "momentum" <> value 0.9)
                                   <*> option auto (long "l2" <> value 0.0005)
                                   )

generateRecurrent :: RecNet -> RecInput -> S ('D1 1) -> [Int]
generateRecurrent n s i =
  unfoldr go (s, i)
  where
    go (x, y) =
      do let (ns, o) = runRecurrent n x y
             o' = heat o
         Just (o', (ns, fromIntegral o'))

heat :: S ('D1 1) -> Int
heat x = case x of
  (S1D v) -> round (SA.mean v)

main :: IO ()
main = do
  FeedForwardOpts examples rate <- execParser (info (feedForward' <**> helper) idm)
  putStrLn "Training network..."

  (net0, i0) <- randomNet
  (trained, bestInput) <- netTest net0 i0 rate examples

  let results = generateRecurrent trained bestInput (c 1)

  print . take 50 . drop 100 $ results

c :: Double -> S ('D1 1)
c = S1D . SA.konst
main/shakespeare.hs (new file, 156 lines)
@@ -0,0 +1,156 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE LambdaCase #-}

import Control.Monad.Random
import Control.Monad.Trans.Except

import Data.Char ( isUpper, toUpper, toLower )
import Data.List ( unfoldr, foldl' )
import Data.Maybe ( fromMaybe )

import qualified Data.Vector as V
import Data.Vector ( Vector )

import qualified Data.Map as M
import Data.Proxy ( Proxy (..) )

import Data.Singletons.Prelude
import GHC.TypeLits

import Numeric.LinearAlgebra.Static ( konst )

import Options.Applicative

import Grenade
import Grenade.Recurrent
import Grenade.Utils.OneHot

-- The definition for our natural language recurrent network.
-- This network is able to learn and generate simple words in
-- about an hour.
--
-- This is a first class recurrent net, although it's similar to
-- an unrolled graph.
--
-- The F and R types are tagging types to ensure that the runner and
-- creation function know how to treat the layers.
--
-- As an example, here's a short sequence generated.
--
-- > the see and and the sir, and and the make and the make and go the make and go the make and the
--
type F = FeedForward
type R = Recurrent

-- The definition of our network
type Shakespeare = RecurrentNetwork '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit]
                                    '[ 'D1 40, 'D1 40, 'D1 40, 'D1 40 ]

-- The definition of the "sideways" input, which the network is fed recurrently.
type Shakespearian = RecurrentInputs '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit]

randomNet :: MonadRandom m => m (Shakespeare, Shakespearian)
randomNet = randomRecurrent

-- | Load the data files and prepare a map of characters to a compressed int representation.
loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Vector Char)
loadShakespeare path = do
  contents <- lift $ readFile path
  let annotated = annotateCapitals contents
  (m,cs) <- ExceptT . return . note "Couldn't fit data in hotMap" $ hotMap (Proxy :: Proxy 40) annotated
  hot <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated
  return (V.fromList hot, m, cs)

trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
trainSlice !rate !net !recIns input offset size =
  let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
  in case reverse e of
       (o : l : xs) ->
         let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
         in trainRecurrent rate net recIns examples
       _ -> error "Not enough input"
  where
    x = fromMaybe (error "Hot variable didn't fit.")

runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
runShakespeare ShakespeareOpts {..} = do
  (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
  (net0, i0) <- lift randomNet
  lift $ foldM_ (\(!net, !io) size -> do
    xs <- take (iterations `div` 15) <$> getRandomRs (0, length shakespeare - size - 1)
    let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
    let results = take 100 $ generateParagraph trained bestInput oneHotMap oneHotDictionary ( S1D $ konst 0)
    putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
    putStrLn (unAnnotateCapitals results)
    return (trained, bestInput)
    ) (net0, i0) [10,10,15,15,20,20,25,25,30,30,35,35,40,40,50 :: Int]

generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes ~ 'D1 n, KnownNat n, Ord a)
                  => RecurrentNetwork layers shapes
                  -> RecurrentInputs layers
                  -> M.Map a Int
                  -> Vector a
                  -> S ('D1 n)
                  -> [a]
generateParagraph n s hotmap hotdict i =
  unfoldr go (s, i)
  where
    go (x, y) =
      do let (ns, o) = runRecurrent n x y
         un <- unHot hotdict o
         re <- makeHot hotmap un
         Just (un, (ns, re))

data ShakespeareOpts = ShakespeareOpts {
    trainingFile :: FilePath
  , iterations   :: Int
  , rate         :: LearningParameters
  }

shakespeare' :: Parser ShakespeareOpts
shakespeare' = ShakespeareOpts <$> argument str (metavar "TRAIN")
                               <*> option auto (long "examples" <> short 'e' <> value 1000000)
                               <*> (LearningParameters
                                   <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
                                   <*> option auto (long "momentum" <> value 0.95)
                                   <*> option auto (long "l2" <> value 0.000001)
                                   )

main :: IO ()
main = do
  shopts <- execParser (info (shakespeare' <**> helper) idm)
  res <- runExceptT $ runShakespeare shopts
  case res of
    Right () -> pure ()
    Left err -> putStrLn err

-- Replace capitals with an annotation and the lower case letter
-- http://fastml.com/one-weird-trick-for-training-char-rnns/
annotateCapitals :: String -> String
annotateCapitals (x : rest)
    | isUpper x
    = '^' : toLower x : annotateCapitals rest
    | otherwise
    = x : annotateCapitals rest
annotateCapitals []
    = []

unAnnotateCapitals :: String -> String
unAnnotateCapitals ('^' : x : rest)
    = toUpper x : unAnnotateCapitals rest
unAnnotateCapitals (x : rest)
    = x : unAnnotateCapitals rest
unAnnotateCapitals []
    = []

-- | Tag the 'Nothing' value of a 'Maybe'
note :: a -> Maybe b -> Either a b
note a = maybe (Left a) Right
src/Grenade/Core/Network.hs

@@ -1,15 +1,21 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-|
Module      : Grenade.Core.Network
Description : Core definition of a simple neural network
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental

This module defines the core data type for the simplest
neural network we support.

-}
module Grenade.Core.Network (
    Layer (..)
  , Network (..)
@@ -20,10 +26,12 @@ module Grenade.Core.Network (
  ) where

import Control.Monad.Random (MonadRandom)

import Data.List ( foldl' )
import Data.Singletons

import Grenade.Core.Shape

-- | Learning parameters for stochastic gradient descent.
data LearningParameters = LearningParameters {
    learningRate :: Double
  , learningMomentum :: Double
@@ -33,35 +41,43 @@ data LearningParameters = LearningParameters {
-- | Class for updating a layer. All layers implement this, and it is
-- shape independent.
class Show x => UpdateLayer x where
  {-# MINIMAL runUpdate, createRandom #-}
  -- | The type for the gradient for this layer.
  --   Unit if there isn't a gradient to pass back.
  type Gradient x :: *
  -- | Update a layer with its gradient and learning parameters
  runUpdate :: LearningParameters -> x -> Gradient x -> x

  -- | Create a random layer, many layers will use pure
  createRandom :: MonadRandom m => m x

  -- | Update a layer with many Gradients
  runUpdates :: LearningParameters -> x -> [Gradient x] -> x
  runUpdates rate = foldl' (runUpdate rate)

-- | Class for a layer. All layers implement this, however, they don't
-- need to implement it for all shapes, only ones which are appropriate.
class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
  -- | Used in training and scoring. Take the input from the previous
  --   layer, and give the output from this layer.
  runForwards :: x -> S' i -> S' o
  runForwards :: x -> S i -> S o
  -- | Back propagate a step. Takes the current layer, the input that the
  --   layer gave from the input and the back propagated derivatives from
  --   the layer above.
  --   Returns the gradient layer and the derivatives to push back further.
  runBackwards :: x -> S' i -> S' o -> (Gradient x, S' i)
  runBackwards :: x -> S i -> S o -> (Gradient x, S i)

-- | Type of a network.
-- The [*] type specifies the types of the layers. This is needed for parallel
-- running and bringing all the gradients back together.
--
-- The [*] type specifies the types of the layers.
--
-- The [Shape] type specifies the shapes of data passed between the layers.
-- Could be considered to be a heterogeneous list of layers which are able to
--
-- Can be considered to be a heterogeneous list of layers which are able to
-- transform the data shapes of the network.
data Network :: [*] -> [Shape] -> * where
    O :: Layer x i o => !x -> Network '[x] '[i, o]
    (:~>) :: Layer x i h => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
    O :: (SingI i, SingI o, Layer x i o) => !x -> Network '[x] '[i, o]
    (:~>) :: (SingI i, SingI h, Layer x i h) => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
infixr 5 :~>

instance Show (Network l h) where
@@ -74,15 +90,14 @@ data Gradients :: [*] -> * where
    OG :: UpdateLayer x => Gradient x -> Gradients '[x]
    (:/>) :: UpdateLayer x => Gradient x -> Gradients xs -> Gradients (x ': xs)


-- | A network can easily be created by hand with (:~>), but an easy way to initialise a random
--   network is with the randomNetwork.
class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
  -- | Create a network of the types requested
  randomNetwork :: MonadRandom m => m (Network xs ss)

instance Layer x i o => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
instance (SingI i, SingI o, Layer x i o) => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
  randomNetwork = O <$> createRandom

instance (Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
  randomNetwork = (:~>) <$> createRandom <*> randomNetwork
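To make the two classes concrete, here is a minimal sketch of a parameter-free layer written against this API (an illustrative `Identity` layer, not part of this commit):

```haskell
-- An identity layer: no weights, no gradient, shapes pass straight through.
data Identity = Identity
  deriving Show

instance UpdateLayer Identity where
  type Gradient Identity = ()
  runUpdate _ _ _ = Identity
  createRandom    = return Identity

instance Layer Identity ('D1 n) ('D1 n) where
  runForwards  _ x      = x
  runBackwards _ _ dEdy = ((), dEdy)
```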
src/Grenade/Core/Runner.hs

@@ -4,7 +4,16 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-|
Module      : Grenade.Core.Runner
Description : Back propagation and training functions for a network
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental

This module defines simple back propagation and training functions
for a network.
-}
module Grenade.Core.Runner (
    train
  , backPropagate
@@ -16,16 +25,22 @@ import Data.Singletons.Prelude
import Grenade.Core.Network
import Grenade.Core.Shape

-- | Drive a network and collect its back propagated gradients.
backPropagate :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
              => Network layers shapes -> S' input -> S' output -> Gradients layers
-- | Perform reverse automatic differentiation on the network
--   for the current input and expected output.
--
--   /Note:/ The loss function pushed backwards is appropriate
--   for both regression and classification as a squared loss
--   or log-loss respectively. Other loss functions are not yet
--   implemented.
backPropagate :: forall shapes layers.
                 Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Gradients layers
backPropagate network input target =
    fst $ go input network
  where
    go :: forall j js sublayers. (Head js ~ j, Last js ~ output)
       => S' j                  -- ^ input vector
    go :: forall js sublayers. (Last js ~ Last shapes)
       => S (Head js)           -- ^ input vector
       -> Network sublayers js  -- ^ network to train
       -> (Gradients sublayers, S' j)
       -> (Gradients sublayers, S (Head js))
    -- handle input from the beginning, feeding upwards.
    go !x (layer :~> n)
      = let y = runForwards layer x
@@ -44,16 +59,7 @@ backPropagate network input target =

        in (OG layer', dWs)

-- | Update a network with new weights after training with an instance.
train :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
      => LearningParameters    -- ^ learning rate
      -> Network layers shapes -- ^ network to train
      -> S' input -> S' output -- ^ target vector
      -> Network layers shapes
train rate network input output =
  let grads = backPropagate network input output
  in  applyUpdate rate network grads

-- | Apply one step of stochastic gradient descent across the network.
applyUpdate :: LearningParameters -> Network ls ss -> Gradients ls -> Network ls ss
applyUpdate rate (O layer) (OG gradient)
  = O (runUpdate rate layer gradient)
@@ -62,9 +68,13 @@ applyUpdate rate (layer :~> rest) (gradient :/> grest)
applyUpdate _ _ _
  = error "Impossible for the gradients of a network to have a different length to the network"

-- | Just forwards propagation with no training.
runNet :: Network layers hs
       -> S' (Head hs)         -- ^ input vector
       -> S' (Last hs)         -- ^ target vector
-- | Update a network with new weights after training with an instance.
train :: LearningParameters -> Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Network layers shapes
train rate network input output =
  let grads = backPropagate network input output
  in  applyUpdate rate network grads

-- | Run the network with input and return the given output.
runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes)
runNet (layer :~> n) !x = let y = runForwards layer x in runNet n y
runNet (O layer) !x = runForwards layer x
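Given `train` above, an epoch over a dataset is just a strict left fold, which is how the programs in main/ use it; `trainAll` is an illustrative name and assumes `Data.List (foldl')` is in scope:

```haskell
-- A hypothetical epoch: one SGD step per labelled example.
trainAll :: LearningParameters
         -> Network layers shapes
         -> [(S (Head shapes), S (Last shapes))]
         -> Network layers shapes
trainAll lp = foldl' (\net (i, o) -> train lp net i o)
```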
src/Grenade/Core/Shape.hs

@@ -1,75 +1,171 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}

-- Ghc 8.0 gives a warning on `(+) _ _ = error ...` but ghc 7.10 fails to
-- Ghc 8.0 gives a warning on `n2 _ _ = error ...` but ghc 7.10 fails to
-- compile without this default pattern.
{-# OPTIONS_GHC -fno-warn-overlapping-patterns #-}

{-|
Module      : Grenade.Core.Shape
Description : Core definition of the Shapes of data we understand
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental

This module defines the core data types for the shapes of data that
are understood by Grenade.
-}
module Grenade.Core.Shape (
    Shape (..)
  , S' (..)
  , S (..)
  , randomOfShape
  , fromStorable
  ) where

import Control.DeepSeq (NFData (..))
import Control.Monad.Random ( MonadRandom, getRandom )

import Data.Singletons
import Data.Singletons.TypeLits
import Data.Vector.Storable ( Vector )
import qualified Data.Vector.Storable as V

import GHC.TypeLits

import qualified Numeric.LinearAlgebra.Static as H
import Numeric.LinearAlgebra.Static

import qualified Numeric.LinearAlgebra as NLA

-- | The current shapes we accept.
--   At the moment this is just one, two, and three dimensional
--   vectors/matrices.
data Shape =
    D1 Nat
data Shape
  = D1 Nat
  | D2 Nat Nat
  | D3 Nat Nat Nat

instance Num (S' x) where
  (+) (S1D' x) (S1D' y) = S1D' (x + y)
  (+) (S2D' x) (S2D' y) = S2D' (x + y)
  (+) (S3D' x) (S3D' y) = S3D' (x + y)
  (+) _ _ = error "Impossible to have different constructors for the same shaped network"

  (-) (S1D' x) (S1D' y) = S1D' (x - y)
  (-) (S2D' x) (S2D' y) = S2D' (x - y)
  (-) (S3D' x) (S3D' y) = S3D' (x - y)
  (-) _ _ = error "Impossible to have different constructors for the same shaped network"

  (*) (S1D' x) (S1D' y) = S1D' (x * y)
  (*) (S2D' x) (S2D' y) = S2D' (x * y)
  (*) (S3D' x) (S3D' y) = S3D' (x * y)
  (*) _ _ = error "Impossible to have different constructors for the same shaped network"

  abs (S1D' x) = S1D' (abs x)
  abs (S2D' x) = S2D' (abs x)
  abs (S3D' x) = S3D' (abs x)

  signum (S1D' x) = S1D' (signum x)
  signum (S2D' x) = S2D' (signum x)
  signum (S3D' x) = S3D' (signum x)

  fromInteger _ = error "Unimplemented: fromInteger on Shape"

-- | Given a Shape n, these are the possible data structures with that shape.
--   All shapes are held in contiguous memory.
--   3D is held in a matrix (usually row oriented) which has height depth * rows.
data S' (n :: Shape) where
  S1D' :: ( KnownNat o ) => R o -> S' ('D1 o)
  S2D' :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S' ('D2 rows columns)
  S3D' :: ( KnownNat rows
          , KnownNat columns
          , KnownNat depth
          , KnownNat (rows * depth)) => L (rows * depth) columns -> S' ('D3 rows columns depth)
data S (n :: Shape) where
  S1D :: ( KnownNat o ) => R o -> S ('D1 o)
  S2D :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S ('D2 rows columns)
  S3D :: ( KnownNat rows
         , KnownNat columns
         , KnownNat depth
         , KnownNat (rows * depth)) => L (rows * depth) columns -> S ('D3 rows columns depth)

instance Show (S' n) where
  show (S1D' a) = "S1D' " ++ show a
  show (S2D' a) = "S2D' " ++ show a
  show (S3D' a) = "S3D' " ++ show a
deriving instance Show (S n)

instance SingI x => Num (S x) where
  (+) = n2 (+)
  (-) = n2 (-)
  (*) = n2 (*)
  abs = n1 abs
  signum = n1 signum
  fromInteger x = case (sing :: Sing x) of
    D1Sing -> S1D (konst $ fromInteger x)
    D2Sing -> S2D (konst $ fromInteger x)
    D3Sing -> S3D (konst $ fromInteger x)

instance SingI x => Fractional (S x) where
  (/) = n2 (/)
  recip = n1 recip
  fromRational x = case (sing :: Sing x) of
    D1Sing -> S1D (konst $ fromRational x)
    D2Sing -> S2D (konst $ fromRational x)
    D3Sing -> S3D (konst $ fromRational x)

instance SingI x => Floating (S x) where
  pi = case (sing :: Sing x) of
    D1Sing -> S1D (konst pi)
    D2Sing -> S2D (konst pi)
    D3Sing -> S3D (konst pi)
  exp = n1 exp
  log = n1 log
  sqrt = n1 sqrt
  (**) = n2 (**)
  logBase = n2 logBase
  sin = n1 sin
  cos = n1 cos
  tan = n1 tan
  asin = n1 asin
  acos = n1 acos
  atan = n1 atan
  sinh = n1 sinh
  cosh = n1 cosh
  tanh = n1 tanh
  asinh = n1 asinh
  acosh = n1 acosh
  atanh = n1 atanh

-- Singletons
-- These could probably be derived with template haskell, but this seems
-- clear and makes adding the KnownNat constraints simple.
data instance Sing (n :: Shape) where
  D1Sing :: KnownNat a => Sing ('D1 a)
  D2Sing :: (KnownNat a, KnownNat b) => Sing ('D2 a b)
  D3Sing :: (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => Sing ('D3 a b c)

instance KnownNat a => SingI ('D1 a) where
  sing = D1Sing
instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
  sing = D2Sing
instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
  sing = D3Sing

--
-- I haven't made shapes strict, as sometimes they're not needed
-- (the last input gradient back for instance)
--
instance NFData (S x) where
  rnf (S1D x) = rnf x
  rnf (S2D x) = rnf x
  rnf (S3D x) = rnf x

-- | Generate random data of the desired shape
randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x)
randomOfShape = do
  seed :: Int <- getRandom
  return $ case (sing :: Sing x) of
    D1Sing -> S1D (randomVector seed Uniform * 2 - 1)
    D2Sing -> S2D (uniformSample seed (-1) 1)
    D3Sing -> S3D (uniformSample seed (-1) 1)

-- | Generate a shape from a Storable Vector.
--
--   Returns Nothing if the vector is of the wrong size.
fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x)
fromStorable xs = case sing :: Sing x of
    D1Sing -> S1D <$> H.create xs
    D2Sing -> S2D <$> mkL xs
    D3Sing -> S3D <$> mkL xs
  where
    mkL :: forall rows columns. (KnownNat rows, KnownNat columns)
        => Vector Double -> Maybe (L rows columns)
    mkL v =
      let rows    = fromIntegral $ natVal (Proxy :: Proxy rows)
          columns = fromIntegral $ natVal (Proxy :: Proxy columns)
      in  if rows * columns == V.length v
            then H.create $ NLA.reshape columns v
            else Nothing

-- Helper function for creating the number instances
n1 :: ( forall a. Floating a => a -> a ) -> S x -> S x
n1 f (S1D x) = S1D (f x)
n1 f (S2D x) = S2D (f x)
n1 f (S3D x) = S3D (f x)

-- Helper function for creating the number instances
n2 :: ( forall a. Floating a => a -> a -> a ) -> S x -> S x -> S x
n2 f (S1D x) (S1D y) = S1D (f x y)
n2 f (S2D x) (S2D y) = S2D (f x y)
n2 f (S3D x) (S3D y) = S3D (f x y)
n2 _ _ _ = error "Impossible to have different constructors for the same shaped network"
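As a quick illustration of the new shape API (hypothetical values, with `Data.Vector.Storable` imported qualified as `V` as in the module above): a six-element vector fills a 'D2 2 3 shape, and a mismatched length yields Nothing.

```haskell
matrix23 :: Maybe (S ('D2 2 3))
matrix23 = fromStorable (V.fromList [1, 2, 3, 4, 5, 6])  -- Just (S2D ...)

tooShort :: Maybe (S ('D2 2 3))
tooShort = fromStorable (V.fromList [1, 2, 3])           -- Nothing
```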
src/Grenade/Graph/GraphNetwork.hs (new file, 37 lines)
@@ -0,0 +1,37 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}

module Grenade.Graph.Network (
    Layer (..)
  , UpdateLayer (..)
  ) where

import Control.Monad.Random (MonadRandom)
import Data.Singletons
import Data.Singletons.Prelude

import GHC.TypeLits

import Grenade.Core.Shape
import Grenade.Core.Network ( UpdateLayer (..), Layer (..) )

-- | Type of a DAG network

data Fin :: Nat -> * where
  Fin0 :: Fin (n + 1)
  FinS :: Fin n -> Fin (n + 1)

data Edge :: Nat -> * where
  Edge :: Shape -> Fin n -> Edge n

data Node a n where
  Node :: a -> [Edge n] -> Node a n
src/Grenade/Layers/Convolution.hs

@@ -1,7 +1,5 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
@@ -9,9 +7,6 @@
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE PatternGuards #-}

module Grenade.Layers.Convolution (
    Convolution (..)
  , Convolution' (..)
@@ -31,6 +26,7 @@ import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)
import Grenade.Core.Network
import Grenade.Core.Shape
import Grenade.Layers.Internal.Convolution
import Grenade.Layers.Internal.Update

-- | A convolution layer for a neural network.
--   This uses the im2col convolution trick popularised by Caffe, which essentially turns the
@@ -43,12 +39,12 @@ import Grenade.Layers.Internal.Convolution
--   `out = (in - kernel) / stride + 1` for both dimensions.
--
--   One probably shouldn't build their own layer, but rather use the randomConvolution function.
data Convolution :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
                 -> Nat -- ^ Number of filters, this is the number of channels output by the layer.
                 -> Nat -- ^ The number of rows in the kernel filter
                 -> Nat -- ^ The number of columns in the kernel filter
                 -> Nat -- ^ The row stride of the convolution filter
                 -> Nat -- ^ The column stride of the convolution filter
data Convolution :: Nat -- Number of channels, for the first layer this could be RGB for instance.
                 -> Nat -- Number of filters, this is the number of channels output by the layer.
                 -> Nat -- The number of rows in the kernel filter
                 -> Nat -- The number of columns in the kernel filter
                 -> Nat -- The row stride of the convolution filter
                 -> Nat -- The column stride of the convolution filter
                 -> * where
  Convolution :: ( KnownNat channels
                 , KnownNat filters
@@ -58,16 +54,16 @@ data Convolution :: Nat -- ^ Number of channels, for the first layer this could
                 , KnownNat strideColumns
                 , KnownNat kernelFlattened
                 , kernelFlattened ~ (kernelRows * kernelColumns * channels))
              => !(L kernelFlattened filters) -- ^ The kernel filter weights
              -> !(L kernelFlattened filters) -- ^ The last kernel update (or momentum)
              => !(L kernelFlattened filters) -- The kernel filter weights
              -> !(L kernelFlattened filters) -- The last kernel update (or momentum)
              -> Convolution channels filters kernelRows kernelColumns strideRows strideColumns

data Convolution' :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
                  -> Nat -- ^ Number of filters, this is the number of channels output by the layer.
                  -> Nat -- ^ The number of rows in the kernel filter
                  -> Nat -- ^ The number of columns in the kernel filter
                  -> Nat -- ^ The row stride of the convolution filter
                  -> Nat -- ^ The column stride of the convolution filter
data Convolution' :: Nat -- Number of channels, for the first layer this could be RGB for instance.
                  -> Nat -- Number of filters, this is the number of channels output by the layer.
                  -> Nat -- The number of rows in the kernel filter
                  -> Nat -- The number of columns in the kernel filter
                  -> Nat -- The row stride of the convolution filter
                  -> Nat -- The column stride of the convolution filter
                  -> * where
  Convolution' :: ( KnownNat channels
                  , KnownNat filters
@@ -77,7 +73,7 @@ data Convolution' :: Nat -- ^ Number of channels, for the first layer this could
                  , KnownNat strideColumns
                  , KnownNat kernelFlattened
                  , kernelFlattened ~ (kernelRows * kernelColumns * channels))
               => !(L kernelFlattened filters) -- ^ The kernel filter gradient
               => !(L kernelFlattened filters) -- The kernel filter gradient
               -> Convolution' channels filters kernelRows kernelColumns strideRows strideColumns

instance Show (Convolution c f k k' s s') where
@@ -109,7 +105,7 @@ randomConvolution :: ( MonadRandom m
                     , kernelFlattened ~ (kernelRows * kernelColumns * channels))
                  => m (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
randomConvolution = do
    s :: Int <- getRandom
    s <- getRandom
    let wN = uniformSample s (-1) 1
        mm = konst 0
    return $ Convolution wN mm
@@ -124,9 +120,7 @@ instance ( KnownNat channels
         ) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
  type Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols) = (Convolution' channels filters kernelRows kernelCols strideRows strideCols)
  runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
    let newMomentum = konst learningMomentum * oldMomentum - konst learningRate * kernelGradient
        regulariser = konst (learningRegulariser * learningRate) * oldKernel
        newKernel = oldKernel + newMomentum - regulariser
    let (newKernel, newMomentum) = decendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
    in Convolution newKernel newMomentum

  createRandom = randomConvolution
@@ -146,7 +140,7 @@ instance ( KnownNat kernelRows
         , KnownNat (kernelRows * kernelCols * 1)
         , KnownNat (outputRows * filters)
         ) => Layer (Convolution 1 filters kernelRows kernelCols strideRows strideCols) ('D2 inputRows inputCols) ('D3 outputRows outputCols filters) where
  runForwards (Convolution kernel _) (S2D' input) =
  runForwards (Convolution kernel _) (S2D input) =
    let ex = extract input
        ek = extract kernel
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@@ -159,9 +153,9 @@ instance ( KnownNat kernelRows
        mt = c LA.<> ek
        r  = col2vid 1 1 1 1 ox oy mt
        rs = fromJust . create $ r
    in S3D' rs
    in S3D rs

  runBackwards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
  runBackwards (Convolution kernel _) (S2D input) (S3D dEdy) =
    let ex = extract input
        ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
@@ -183,7 +177,7 @@ instance ( KnownNat kernelRows
        dW = vs LA.<> tr ek

        xW = col2im kx ky sx sy ix iy dW
    in (Convolution' kN, S2D' . fromJust . create $ xW)
    in (Convolution' kN, S2D . fromJust . create $ xW)


-- | A three dimensional image (or 2d with many channels) can have
@@ -203,7 +197,7 @@ instance ( KnownNat kernelRows
         , KnownNat (kernelRows * kernelCols * channels)
         , KnownNat (outputRows * filters)
         ) => Layer (Convolution channels filters kernelRows kernelCols strideRows strideCols) ('D3 inputRows inputCols channels) ('D3 outputRows outputCols filters) where
  runForwards (Convolution kernel _) (S3D' input) =
  runForwards (Convolution kernel _) (S3D input) =
    let ex = extract input
        ek = extract kernel
        ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
@@ -219,8 +213,8 @@ instance ( KnownNat kernelRows
        mt = c LA.<> ek
        r  = col2vid 1 1 1 1 ox oy mt
        rs = fromJust . create $ r
    in S3D' rs
  runBackwards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
    in S3D rs
  runBackwards (Convolution kernel _) (S3D input) (S3D dEdy) =
    let ex = extract input
        ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
@@ -243,4 +237,4 @@ instance ( KnownNat kernelRows
        dW = vs LA.<> tr ek

        xW = col2vid kx ky sx sy ix iy dW
    in (Convolution' kN, S3D' . fromJust . create $ xW)
    in (Convolution' kN, S3D . fromJust . create $ xW)
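A quick sanity check of the size rule quoted in the layer's documentation, out = (in - kernel) / stride + 1 (an illustrative helper, not part of the module):

```haskell
-- For example, a 28-wide input with a 5-wide kernel at stride 1
-- gives (28 - 5) `div` 1 + 1 = 24 outputs along that dimension.
outputSize :: Int -> Int -> Int -> Int
outputSize input kernel stride = (input - kernel) `div` stride + 1
```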
src/Grenade/Layers/Crop.hs

@@ -4,10 +4,6 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PolyKinds #-}

module Grenade.Layers.Crop (
    Crop (..)
  ) where
@@ -50,19 +46,19 @@ instance ( KnownNat cropLeft
         , (inputRows - cropTop - cropBottom) ~ outputRows
         , (inputColumns - cropLeft - cropRight) ~ outputColumns
         ) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
  runForwards Crop (S2D' input) =
  runForwards Crop (S2D input) =
    let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
        cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
        nrows = fromIntegral $ natVal (Proxy :: Proxy outputRows)
        ncols = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
        m     = extract input
        r     = subMatrix (cropt, cropl) (nrows, ncols) m
    in S2D' . fromJust . create $ r
  runBackwards _ _ (S2D' dEdy) =
    in S2D . fromJust . create $ r
  runBackwards _ _ (S2D dEdy) =
    let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
        cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
        cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight)
        cropb = fromIntegral $ natVal (Proxy :: Proxy cropBottom)
        eo    = extract dEdy
        vs    = diagBlock [konst 0 (cropt,cropl), eo, konst 0 (cropb,cropr)]
    in ((), S2D' . fromJust . create $ vs)
    in ((), S2D . fromJust . create $ vs)
src/Grenade/Layers/Dropout.hs

@@ -1,12 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}

module Grenade.Layers.Dropout (
    Dropout (..)
  , randomDropout
@@ -45,7 +40,7 @@ randomDropout rate = do
  return $ Dropout xs

instance (KnownNat i) => Layer (Dropout i) ('D1 i) ('D1 i) where
  runForwards (Dropout drops) (S1D' x) = S1D' $ x * drops
  runForwards (Pass rate) (S1D' x) = S1D' $ dvmap (* (1 - rate)) x
  runBackwards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops)
  runBackwards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x)
  runForwards (Dropout drops) (S1D x) = S1D $ x * drops
  runForwards (Pass rate) (S1D x) = S1D $ dvmap (* (1 - rate)) x
  runBackwards (Dropout drops) _ (S1D x) = ((), S1D $ x * drops)
  runBackwards (Pass rate) _ (S1D x) = ((), S1D $ dvmap (* (1 - rate)) x)
src/Grenade/Layers/Flatten.hs

@@ -1,13 +1,8 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}

module Grenade.Layers.Flatten (
    FlattenLayer (..)
  ) where
@@ -16,11 +11,16 @@ import Data.Singletons.TypeLits
import GHC.TypeLits

import Numeric.LinearAlgebra.Static
import Numeric.LinearAlgebra.Data as LA (flatten, toList)
import Numeric.LinearAlgebra.Data as LA ( flatten )

import Grenade.Core.Shape
import Grenade.Core.Network

-- | Flatten Layer
--
--   Flattens input down to D1 from either 2D or 3D data.
--
--   Can also be used to turn a 3D image with only one channel into a 2D image.
data FlattenLayer = FlattenLayer
  deriving Show

@@ -29,11 +29,18 @@ instance UpdateLayer FlattenLayer where
  runUpdate _ _ _ = FlattenLayer
  createRandom = return FlattenLayer

instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
  runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
  runBackwards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
  runForwards _ (S2D y) = fromJust' . fromStorable . flatten . extract $ y
  runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)

instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
  runForwards _ (S3D' y) = S1D' . fromList . toList . flatten . extract $ y
  runBackwards _ _ (S1D' y) = ((), S3D' . fromList . toList . unwrap $ y)
  runForwards _ (S3D y) = fromJust' . fromStorable . flatten . extract $ y
  runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)

instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer FlattenLayer ('D3 x y z) ('D2 x y) where
  runForwards _ (S3D y) = S2D y
  runBackwards _ _ (S2D y) = ((), S3D y)

fromJust' :: Maybe x -> x
fromJust' (Just x) = x
fromJust' Nothing = error $ "FlattenLayer error: data shape couldn't be converted."
@ -1,11 +1,8 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

module Grenade.Layers.FullyConnected (
    FullyConnected (..)
  , randomFullyConnected
@ -20,6 +17,8 @@ import Numeric.LinearAlgebra.Static
import Grenade.Core.Network
import Grenade.Core.Shape

import Grenade.Layers.Internal.Update

-- | A basic fully connected (or inner product) neural network layer.
data FullyConnected i o = FullyConnected
  !(R o) -- Bias neuron weights
@ -38,32 +37,29 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
  type Gradient (FullyConnected i o) = (FullyConnected' i o)

  runUpdate LearningParameters {..} (FullyConnected oldBias oldBiasMomentum oldActivations oldMomentum) (FullyConnected' biasGradient activationGradient) =
    let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
        newBias = oldBias + newBiasMomentum
        newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
        regulariser = konst (learningRegulariser * learningRate) * oldActivations
        newActivations = oldActivations + newMomentum - regulariser
    let (newBias, newBiasMomentum) = decendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum
        (newActivations, newMomentum) = decendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum
    in FullyConnected newBias newBiasMomentum newActivations newMomentum

  createRandom = randomFullyConnected

instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
  -- Do a matrix vector multiplication and return the result.
  runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)
  runForwards (FullyConnected wB _ wN _) (S1D v) = S1D (wB + wN #> v)

  -- Run a backpropagation step for a fully connected layer.
  runBackwards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
  runBackwards (FullyConnected _ _ wN _) (S1D x) (S1D dEdy) =
    let wB' = dEdy
        mm' = dEdy `outer` x
        -- calculate derivatives for next step
        dWs = tr wN #> dEdy
    in (FullyConnected' wB' mm', S1D' dWs)
    in (FullyConnected' wB' mm', S1D dWs)

randomFullyConnected :: (MonadRandom m, KnownNat i, KnownNat o)
                     => m (FullyConnected i o)
randomFullyConnected = do
  s1 :: Int <- getRandom
  s2 :: Int <- getRandom
  s1 <- getRandom
  s2 <- getRandom
  let wB = randomVector s1 Uniform * 2 - 1
      wN = uniformSample s2 (-1) 1
      bm = konst 0
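The old inline code above spells out the rule that `decendVector` and `decendMatrix` now perform in a single C pass: a momentum step followed by L2 weight decay. A minimal pure sketch of the same rule on plain lists (the name `descendList` is illustrative, not part of the library):

```haskell
-- One descent step per weight: m' = momentum * m - rate * g,
-- then w' = w + m' - rate * regulariser * w.
descendList :: Double -> Double -> Double
            -> [Double] -> [Double] -> [Double]
            -> ([Double], [Double])
descendList rate momentum regulariser ws gs lasts =
  let ms' = zipWith (\l g -> momentum * l - rate * g) lasts gs
      ws' = zipWith (\w m -> w + m - rate * regulariser * w) ws ms'
  in (ws', ms')
```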
@ -1,16 +1,11 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}

module Grenade.Layers.Fuse (
    Fuse (..)
  ) where
@ -42,11 +37,11 @@ instance (Layer x i h, Layer y h o) => UpdateLayer (Fuse x y i h o) where

instance (Layer x i h, Layer y h o) => Layer (Fuse x y i h o) i o where
  runForwards (x :$$ y) input =
    let yInput :: S' h = runForwards x input
    let yInput :: S h = runForwards x input
    in runForwards y yInput

  runBackwards (x :$$ y) input backGradient =
    let yInput :: S' h = runForwards x input
    let yInput :: S h = runForwards x input
        (y', yGrad) = runBackwards y yInput backGradient
        (x', xGrad) = runBackwards x input yGrad
    in ((x', y'), xGrad)
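As the instance shows, `Fuse x y i h o` runs layer `x` from shape `i` to `h` and layer `y` from `h` to `o`, presenting the pair as one layer. A hedged sketch of a fused pair (the choice of layers is illustrative only):

```haskell
{-# LANGUAGE DataKinds #-}
-- Two pointwise activations fused into a single layer over a 10-vector.
type SigThenTanh = Fuse Logit Tanh ('D1 10) ('D1 10) ('D1 10)

sigThenTanh :: SigThenTanh
sigThenTanh = Logit :$$ Tanh
```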
@ -6,7 +6,9 @@ module Grenade.Layers.Internal.Convolution (
  , vid2col
  ) where

import Foreign ( mallocForeignPtrArray0, withForeignPtr )
import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )

import Foreign ( mallocForeignPtrArray, withForeignPtr )
import Foreign.Ptr ( Ptr )

import Numeric.LinearAlgebra ( Matrix, flatten, rows, cols )
@ -28,19 +30,19 @@ col2im_c :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol =
  let vec = flatten dataCol
  in unsafePerformIO $ do
    outPtr <- mallocForeignPtrArray0 (height * width * channels)
    let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
    outPtr <- mallocForeignPtrArray (height * width * channels)
    let (inPtr, _) = U.unsafeToForeignPtr0 vec

    withForeignPtr inPtr $ \inPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
        col2im_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
        col2im_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'

    let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
    let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels)
    return $ U.matrixFromVector U.RowMajor (height * channels) width matVec

foreign import ccall unsafe
    col2im_cpu
      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()

vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
vid2col kernelRows kernelColumns strideRows strideColumns height width dataVid =
@ -63,16 +65,16 @@ im2col_c channels height width kernelRows kernelColumns strideRows strideColumns
      kernelSize = kernelRows * kernelColumns
      numberOfPatches = rowOut * colOut
  in unsafePerformIO $ do
    outPtr <- mallocForeignPtrArray0 (numberOfPatches * kernelSize * channels)
    let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
    outPtr <- mallocForeignPtrArray (numberOfPatches * kernelSize * channels)
    let (inPtr, _) = U.unsafeToForeignPtr0 vec

    withForeignPtr inPtr $ \inPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
        im2col_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
        im2col_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'

    let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * kernelSize * channels)
    let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * kernelSize * channels)
    return $ U.matrixFromVector U.RowMajor numberOfPatches (kernelSize * channels) matVec

foreign import ccall unsafe
    im2col_cpu
      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
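To make the patch layout concrete: under standard im2col semantics, a 3×3 single-channel image with a 2×2 kernel at stride 1 yields one output row per kernel position. A hand-worked illustration (`im2col_c` is internal to this module, patch ordering and GHCi formatting are approximate):

```haskell
-- λ> let im = (3><3) [1,2,3, 4,5,6, 7,8,9]
-- λ> im2col_c 1 3 3 2 2 1 1 im
-- (4><4)
--  [ 1.0, 2.0, 4.0, 5.0     -- top-left patch
--  , 2.0, 3.0, 5.0, 6.0     -- top-right patch
--  , 4.0, 5.0, 7.0, 8.0     -- bottom-left patch
--  , 5.0, 6.0, 8.0, 9.0 ]   -- bottom-right patch
```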
@ -4,7 +4,9 @@ module Grenade.Layers.Internal.Pooling (
  , poolBackward
  ) where

import Foreign ( mallocForeignPtrArray0, withForeignPtr )
import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )

import Foreign ( mallocForeignPtrArray, withForeignPtr )
import Foreign.Ptr ( Ptr )

import Numeric.LinearAlgebra ( Matrix , flatten )
@ -19,37 +21,37 @@ poolForward channels height width kernelRows kernelColumns strideRows strideColumns
      colOut = (width - kernelColumns) `div` strideColumns + 1
      numberOfPatches = rowOut * colOut
  in unsafePerformIO $ do
    outPtr <- mallocForeignPtrArray0 (numberOfPatches * channels)
    let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
    outPtr <- mallocForeignPtrArray (numberOfPatches * channels)
    let (inPtr, _) = U.unsafeToForeignPtr0 vec

    withForeignPtr inPtr $ \inPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
        pool_forwards_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
        pool_forwards_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'

    let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * channels)
    let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * channels)
    return $ U.matrixFromVector U.RowMajor (rowOut * channels) colOut matVec

foreign import ccall unsafe
    pool_forwards_cpu
      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()

poolBackward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackward channels height width kernelRows kernelColumns strideRows strideColumns dataIm dataGrad =
  let vecIm = flatten dataIm
      vecGrad = flatten dataGrad
  in unsafePerformIO $ do
    outPtr <- mallocForeignPtrArray0 (height * width * channels)
    let (imPtr, imOffset, _) = U.unsafeToForeignPtr vecIm
    let (gradPtr, gradOffset, _) = U.unsafeToForeignPtr vecGrad
    outPtr <- mallocForeignPtrArray (height * width * channels)
    let (imPtr, _) = U.unsafeToForeignPtr0 vecIm
    let (gradPtr, _) = U.unsafeToForeignPtr0 vecGrad

    withForeignPtr imPtr $ \imPtr' ->
      withForeignPtr gradPtr $ \gradPtr' ->
        withForeignPtr outPtr $ \outPtr' ->
          pool_backwards_cpu imPtr' imOffset gradPtr' gradOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
          pool_backwards_cpu imPtr' gradPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'

    let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
    let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels)
    return $ U.matrixFromVector U.RowMajor (height * channels) width matVec

foreign import ccall unsafe
    pool_backwards_cpu
      :: Ptr Double -> Int -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
      :: Ptr Double -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
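Concretely, and assuming the max pooling that the backing C kernel implements: 2×2 pooling at stride 2 over a 4×4 single-channel image keeps the maximum of each block. An illustrative GHCi session (output formatting approximate):

```haskell
-- λ> let im = (4><4) [1..16]
-- λ> poolForward 1 4 4 2 2 2 2 im
-- (2><2)
--  [  6.0,  8.0
--  , 14.0, 16.0 ]
```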
70
src/Grenade/Layers/Internal/Update.hs
Normal file
@ -0,0 +1,70 @@
{-# LANGUAGE ForeignFunctionInterface #-}
module Grenade.Layers.Internal.Update (
    decendMatrix
  , decendVector
  ) where

import Data.Maybe ( fromJust )
import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )

import Foreign ( mallocForeignPtrArray, withForeignPtr )
import Foreign.Ptr ( Ptr )
import GHC.TypeLits

import Numeric.LinearAlgebra ( Vector, flatten )
import Numeric.LinearAlgebra.Static
import qualified Numeric.LinearAlgebra.Devel as U

import System.IO.Unsafe ( unsafePerformIO )

decendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
decendMatrix rate momentum regulariser weights gradient lastUpdate =
  let (rows, cols) = size weights
      len = rows * cols
      -- Most gradients arrive in ColumnMajor order,
      -- so we transpose here before flattening them
      -- into a vector, to prevent a copy.
      --
      -- This gives a ~15% speed improvement for LSTMs.
      weights' = flatten . tr . extract $ weights
      gradient' = flatten . tr . extract $ gradient
      lastUpdate' = flatten . tr . extract $ lastUpdate
      (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'

      -- Note that it's ColumnMajor, as we did a transpose before
      -- using the internal vectors.
      mw = U.matrixFromVector U.ColumnMajor rows cols vw
      mm = U.matrixFromVector U.ColumnMajor rows cols vm
  in (fromJust . create $ mw, fromJust . create $ mm)

decendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
decendVector rate momentum regulariser weights gradient lastUpdate =
  let len = size weights
      weights' = extract weights
      gradient' = extract gradient
      lastUpdate' = extract lastUpdate
      (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
  in (fromJust $ create vw, fromJust $ create vm)

decendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
decendUnsafe len rate momentum regulariser weights gradient lastUpdate =
  unsafePerformIO $ do
    outWPtr <- mallocForeignPtrArray len
    outMPtr <- mallocForeignPtrArray len
    let (wPtr, _) = U.unsafeToForeignPtr0 weights
    let (gPtr, _) = U.unsafeToForeignPtr0 gradient
    let (lPtr, _) = U.unsafeToForeignPtr0 lastUpdate

    withForeignPtr wPtr $ \wPtr' ->
      withForeignPtr gPtr $ \gPtr' ->
        withForeignPtr lPtr $ \lPtr' ->
          withForeignPtr outWPtr $ \outWPtr' ->
            withForeignPtr outMPtr $ \outMPtr' ->
              decend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'

    return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len)

foreign import ccall unsafe
    decend_cpu
      :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO ()
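A quick usage sketch of `decendMatrix`, with made-up hyperparameters (all values here are illustrative only):

```haskell
{-# LANGUAGE DataKinds #-}
import Numeric.LinearAlgebra.Static
import Grenade.Layers.Internal.Update ( decendMatrix )

-- Returns (new weights, new momentum) after one descent step.
step :: (L 3 2, L 3 2)
step = decendMatrix 0.01 0.9 1.0e-4 w g m
  where
    w = konst 0.5 :: L 3 2  -- current weights
    g = konst 0.1           -- this batch's gradient
    m = konst 0             -- previous momentum (zero on the first step)
```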
@ -1,10 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

module Grenade.Layers.Logit (
    Logit (..)
  ) where
@ -27,17 +24,16 @@ instance UpdateLayer Logit where
  createRandom = return Logit

instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where
  runForwards _ (S1D' y) = S1D' (logistic y)
  runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy))
  runForwards _ (S1D y) = S1D (logistic y)
  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (logistic' y * dEdy))

instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where
  runForwards _ (S2D' y) = S2D' (logistic y)
  runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))
  runForwards _ (S2D y) = S2D (logistic y)
  runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (logistic' y * dEdy))

instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where
  runForwards _ (S3D' y) = S3D' (logistic y)
  runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (logistic' y * dEdy))
  runForwards _ (S3D y) = S3D (logistic y)
  runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (logistic' y * dEdy))

logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x))
@ -4,10 +4,6 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PolyKinds #-}

module Grenade.Layers.Pad (
    Pad (..)
  ) where
@ -50,19 +46,19 @@ instance ( KnownNat padLeft
         , (inputRows + padTop + padBottom) ~ outputRows
         , (inputColumns + padLeft + padRight) ~ outputColumns
         ) => Layer (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
  runForwards Pad (S2D' input) =
  runForwards Pad (S2D input) =
    let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
        padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
        padr = fromIntegral $ natVal (Proxy :: Proxy padRight)
        padb = fromIntegral $ natVal (Proxy :: Proxy padBottom)
        m = extract input
        r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
    in S2D' . fromJust . create $ r
  runBackwards Pad _ (S2D' dEdy) =
    in S2D . fromJust . create $ r
  runBackwards Pad _ (S2D dEdy) =
    let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
        padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
        nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        ncols = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        m = extract dEdy
        vs = subMatrix (padt, padl) (nrows, ncols) m
    in ((), S2D' . fromJust . create $ vs)
    in ((), S2D . fromJust . create $ vs)
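The `diagBlock` trick above is worth a worked example: placing the image between two zero blocks on a block diagonal yields exactly the zero-padded image. Illustrative GHCi session (one row and column of padding on each side):

```haskell
-- λ> let m = (2><2) [1,2,3,4]
-- λ> diagBlock [konst 0 (1,1), m, konst 0 (1,1)]
-- (4><4)
--  [ 0.0, 0.0, 0.0, 0.0
--  , 0.0, 1.0, 2.0, 0.0
--  , 0.0, 3.0, 4.0, 0.0
--  , 0.0, 0.0, 0.0, 0.0 ]
```

The backward pass is then just `subMatrix`, slicing the original image region back out of the incoming gradient.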
@ -1,4 +1,3 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
@ -6,10 +5,7 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PolyKinds #-}

module Grenade.Layers.Pooling (
    Pooling (..)
  ) where
@ -55,7 +51,7 @@ instance ( KnownNat kernelRows
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
  runForwards Pooling (S2D' input) =
  runForwards Pooling (S2D input) =
    let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        width = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@ -65,8 +61,8 @@ instance ( KnownNat kernelRows
        ex = extract input
        r = poolForward 1 height width kx ky sx sy ex
        rs = fromJust . create $ r
    in S2D' $ rs
  runBackwards Pooling (S2D' input) (S2D' dEdy) =
    in S2D $ rs
  runBackwards Pooling (S2D input) (S2D dEdy) =
    let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        width = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@ -76,7 +72,7 @@ instance ( KnownNat kernelRows
        ex = extract input
        eo = extract dEdy
        vs = poolBackward 1 height width kx ky sx sy ex eo
    in ((), S2D' . fromJust . create $ vs)
    in ((), S2D . fromJust . create $ vs)


-- | A three dimensional image can be pooled on each layer.
@ -93,7 +89,7 @@ instance ( KnownNat kernelRows
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
  runForwards Pooling (S3D' input) =
  runForwards Pooling (S3D input) =
    let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@ -104,8 +100,8 @@ instance ( KnownNat kernelRows
        ex = extract input
        r = poolForward ch ix iy kx ky sx sy ex
        rs = fromJust . create $ r
    in S3D' rs
  runBackwards Pooling (S3D' input) (S3D' dEdy) =
    in S3D rs
  runBackwards Pooling (S3D input) (S3D dEdy) =
    let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@ -116,4 +112,4 @@ instance ( KnownNat kernelRows
        ex = extract input
        eo = extract dEdy
        vs = poolBackward ch ix iy kx ky sx sy ex eo
    in ((), S3D' . fromJust . create $ vs)
    in ((), S3D . fromJust . create $ vs)
@ -1,10 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

module Grenade.Layers.Relu (
    Relu (..)
  ) where
@ -27,25 +24,25 @@ instance UpdateLayer Relu where
  createRandom = return Relu

instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where
  runForwards _ (S1D' y) = S1D' (relu y)
  runForwards _ (S1D y) = S1D (relu y)
    where
      relu = LAS.dvmap (\a -> if a <= 0 then 0 else a)
  runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy))
  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (relu' y * dEdy))
    where
      relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1)

instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where
  runForwards _ (S2D' y) = S2D' (relu y)
  runForwards _ (S2D y) = S2D (relu y)
    where
      relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
  runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy))
  runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (relu' y * dEdy))
    where
      relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)

instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j k) where
  runForwards _ (S3D' y) = S3D' (relu y)
  runForwards _ (S3D y) = S3D (relu y)
    where
      relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
  runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (relu' y * dEdy))
  runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (relu' y * dEdy))
    where
      relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
@ -1,10 +1,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

module Grenade.Layers.Tanh (
    Tanh (..)
  ) where
@ -24,16 +21,16 @@ instance UpdateLayer Tanh where
  createRandom = return Tanh

instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
  runForwards _ (S1D' y) = S1D' (tanh y)
  runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy))
  runForwards _ (S1D y) = S1D (tanh y)
  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (tanh' y * dEdy))

instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where
  runForwards _ (S2D' y) = S2D' (tanh y)
  runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))
  runForwards _ (S2D y) = S2D (tanh y)
  runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (tanh' y * dEdy))

instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where
  runForwards _ (S3D' y) = S3D' (tanh y)
  runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (tanh' y * dEdy))
  runForwards _ (S3D y) = S3D (tanh y)
  runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (tanh' y * dEdy))

tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
9
src/Grenade/Recurrent.hs
Normal file
@ -0,0 +1,9 @@
module Grenade.Recurrent (
    module X
  ) where

import Grenade.Recurrent.Core.Network as X
import Grenade.Recurrent.Core.Runner as X
import Grenade.Recurrent.Layers.BasicRecurrent as X
import Grenade.Recurrent.Layers.LSTM as X
import Grenade.Recurrent.Layers.Trivial as X
98
src/Grenade/Recurrent/Core/Network.hs
Normal file
@ -0,0 +1,98 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE EmptyDataDecls #-}
module Grenade.Recurrent.Core.Network (
    Recurrent
  , FeedForward
  , RecurrentLayer (..)
  , RecurrentUpdateLayer (..)
  , RecurrentNetwork (..)
  , RecurrentInputs (..)
  , CreatableRecurrent (..)
  ) where

import Control.Monad.Random ( MonadRandom )
import Data.Singletons ( SingI )

import Grenade.Core.Shape
import Grenade.Core.Network

-- | Witness type to indicate we're building up with a normal feed
--   forward layer.
data FeedForward :: * -> *
-- | Witness type to indicate we're building up with a recurrent layer.
data Recurrent :: * -> *

-- | Class for a recurrent layer.
--   It's quite similar to a normal layer, but with the addition of
--   an extra recurrent data shape passed between runs.
class UpdateLayer x => RecurrentUpdateLayer x where
  -- | Shape of data that is passed between each subsequent run of the layer
  type RecurrentShape x :: Shape

class (RecurrentUpdateLayer x, SingI (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
  -- | Used in training and scoring. Take the input from the previous
  --   layer, and give the output from this layer.
  runRecurrentForwards :: x -> S (RecurrentShape x) -> S i -> (S (RecurrentShape x), S o)
  -- | Back propagate a step. Takes the current layer, the input that the
  --   layer gave from the input and the back propagated derivatives from
  --   the layer above.
  --   Returns the gradient layer and the derivatives to push back further.
  runRecurrentBackwards :: x -> S (RecurrentShape x) -> S i -> S (RecurrentShape x) -> S o -> (Gradient x, S (RecurrentShape x), S i)

data RecurrentNetwork :: [*] -> [Shape] -> * where
  OR :: (SingI i, SingI o, Layer x i o) => !x -> RecurrentNetwork '[FeedForward x] '[i, o]
  (:~~>) :: (SingI i, Layer x i h) => !x -> !(RecurrentNetwork xs (h ': hs)) -> RecurrentNetwork (FeedForward x ': xs) (i ': h ': hs)
  (:~@>) :: (SingI i, RecurrentLayer x i h) => !x -> !(RecurrentNetwork xs (h ': hs)) -> RecurrentNetwork (Recurrent x ': xs) (i ': h ': hs)
infixr 5 :~~>
infixr 5 :~@>

instance Show (RecurrentNetwork l h) where
  show (OR a) = "OR " ++ show a
  show (i :~~> o) = show i ++ "\n:~~>\n" ++ show o
  show (i :~@> o) = show i ++ "\n:~@>\n" ++ show o

-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph).
--   Parameterised on the layers of a Network.
data RecurrentInputs :: [*] -> * where
  ORS :: UpdateLayer x
      => () -> RecurrentInputs '[FeedForward x]
  (:~~+>) :: UpdateLayer x
          => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs)
  (:~@+>) :: (SingI (RecurrentShape x), RecurrentUpdateLayer x)
          => !(S (RecurrentShape x)) -> !(RecurrentInputs xs) -> RecurrentInputs (Recurrent x ': xs)
infixr 5 :~~+>
infixr 5 :~@+>

-- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
--   recurrent network and a set of random inputs for it is with randomRecurrent.
class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where
  -- | Create a network of the types requested
  randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss, RecurrentInputs xs)

instance (SingI i, SingI o, Layer x i o) => CreatableRecurrent (FeedForward x ': '[]) (i ': o ': '[]) where
  randomRecurrent = do
    thisLayer <- createRandom
    return (OR thisLayer, ORS ())

instance (SingI i, Layer x i o, CreatableRecurrent xs (o ': r ': rs)) => CreatableRecurrent (FeedForward x ': xs) (i ': o ': r ': rs) where
  randomRecurrent = do
    thisLayer <- createRandom
    (rest, resti) <- randomRecurrent
    return (thisLayer :~~> rest, () :~~+> resti)

instance (SingI i, RecurrentLayer x i o, CreatableRecurrent xs (o ': r ': rs)) => CreatableRecurrent (Recurrent x ': xs) (i ': o ': r ': rs) where
  randomRecurrent = do
    thisLayer <- createRandom
    thisShape <- randomOfShape
    (rest, resti) <- randomRecurrent
    return (thisLayer :~@> rest, thisShape :~@+> resti)
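A hedged sketch of how these constructors compose (the layer sizes are illustrative, not from this commit): a recurrent cell feeding a feed-forward readout, with the whole pair initialised in one go by `randomRecurrent`.

```haskell
{-# LANGUAGE DataKinds #-}
-- A BasicRecurrent cell from a 4-vector to an 8-vector,
-- followed by a fully connected readout down to 4.
type TinyRecNet =
  RecurrentNetwork '[ Recurrent (BasicRecurrent 4 8), FeedForward (FullyConnected 8 4) ]
                   '[ 'D1 4, 'D1 8, 'D1 4 ]

randomTinyRecNet :: MonadRandom m
                 => m ( TinyRecNet
                      , RecurrentInputs '[ Recurrent (BasicRecurrent 4 8), FeedForward (FullyConnected 8 4) ] )
randomTinyRecNet = randomRecurrent
```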
144
src/Grenade/Recurrent/Core/Runner.hs
Normal file
@ -0,0 +1,144 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Grenade.Recurrent.Core.Runner (
    trainRecurrent
  , runRecurrent
  ) where

import Data.Singletons.Prelude
import Grenade.Core.Network
import Grenade.Core.Shape

import Grenade.Recurrent.Core.Network

-- | Drive a network and collect its back propagated gradients.
trainRecurrent :: forall shapes layers. SingI (Last shapes)
               => LearningParameters
               -> RecurrentNetwork layers shapes
               -> RecurrentInputs layers
               -> [(S (Head shapes), Maybe (S (Last shapes)))]
               -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
trainRecurrent rate network recinputs examples =
  updateBack $ go inputs network recinputs
  where
    inputs = fst <$> examples
    targets = snd <$> examples
    updateBack (a,recgrad,_) = (a, updateRecInputs rate recinputs recgrad)

    go :: forall js sublayers. (Last js ~ Last shapes)
       => [S (Head js)]                 -- ^ input vector
       -> RecurrentNetwork sublayers js -- ^ network to train
       -> RecurrentInputs sublayers
       -> (RecurrentNetwork sublayers js, RecurrentInputs sublayers, [S (Head js)])

    -- This is a simple non-recurrent layer, just map it forwards.
    -- Note we're doing training here; we could just return a list of gradients
    -- (and probably will in future).
    go !xs (layer :~~> n) (() :~~+> nIn)
      = let ys = runForwards layer <$> xs
            -- recursively run the rest of the network, and get the gradients from above.
            (newFN, ig, grads) = go ys n nIn
            -- calculate the gradient for this layer to pass down,
            back = uncurry (runBackwards layer) <$> zip (reverse xs) grads
            -- the new trained layer.
            newlayer = runUpdates rate layer (fst <$> back)

        in (newlayer :~~> newFN, () :~~+> ig, snd <$> back)

    -- This is a recurrent layer, so we need to do a scan, first input to last, providing
    -- the recurrent shape output to the next layer.
    go !xs (layer :~@> n) (g :~@+> nIn)
      = let ys = scanlFrom layer g xs

            (newFN, ig, grads) = go (snd <$> ys) n nIn

            backExamples = zip3 (fst <$> reverse ys) (reverse xs) grads

            (rg, back) = myscanbackward layer backExamples
            -- the new trained layer.
            newlayer = runUpdates rate layer (fst <$> back)
        in (newlayer :~@> newFN, rg :~@+> ig, snd <$> back)

    -- Handle the output layer, bouncing the derivatives back down.
    -- We may not have a target for each example, so when we don't, use a 0 gradient.
    go !xs (OR layer) (ORS ())
      = let ys = runForwards layer <$> xs
            -- compute the error gradients against the targets.
            back = uncurry (runBackwards layer) <$> zip xs (zipWith makeError ys targets)
            -- the new trained layer.
            newlayer = runUpdates rate layer (reverse $ fst <$> back)
        in (OR newlayer, ORS (), reverse (snd <$> back))

    go _ _ _ =
      error "Impossible for network and recurrent inputs to have different shapes"

    makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes)
    makeError _ Nothing = 0
    makeError y (Just t) = y - t

updateRecInputs :: forall sublayers.
     LearningParameters
  -> RecurrentInputs sublayers
  -> RecurrentInputs sublayers
  -> RecurrentInputs sublayers

updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
  = () :~~+> updateRecInputs l xs ys

updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
  = (x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys

updateRecInputs _ (ORS ()) (ORS ())
  = ORS ()
updateRecInputs _ _ _
  = error "Impossible for updateRecInputs to have different shapes"

scanlFrom :: forall x i o. RecurrentLayer x i o
          => x                             -- ^ the layer
          -> S (RecurrentShape x)          -- ^ place to start
          -> [S i]                         -- ^ list of inputs to scan through
          -> [(S (RecurrentShape x), S o)] -- ^ list of scan inputs and outputs
scanlFrom !layer !recShape (x:xs) =
  let (lerec, lepush) = runRecurrentForwards layer recShape x
  in (recShape, lepush) : scanlFrom layer lerec xs
scanlFrom _ _ [] = []

myscanbackward :: forall x i o. RecurrentLayer x i o
               => x                                           -- ^ the layer
               -> [(S (RecurrentShape x), S i, S o)]          -- ^ the list of inputs and output to scan over
               -> (S (RecurrentShape x), [(Gradient x, S i)]) -- ^ list of gradients to fold and inputs to backprop
myscanbackward layer =
  goX 0
  where
    goX :: S (RecurrentShape x) -> [(S (RecurrentShape x), S i, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)])
    goX !lastback ((recShape, lastin, backgrad):xs) =
      let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recShape lastin lastback backgrad
          (pushedback, ll) = goX recgrad xs
      in (pushedback, (layergrad, ingrad) : ll)
    goX !lastback [] = (lastback, [])

-- | Just forwards propagation with no training.
runRecurrent :: RecurrentNetwork layers shapes
             -> RecurrentInputs layers -> S (Head shapes)
             -> (RecurrentInputs layers, S (Last shapes))
runRecurrent (layer :~~> n) (() :~~+> nr) !x
  = let ys = runForwards layer x
        (nr', o) = runRecurrent n nr ys
    in (() :~~+> nr', o)
runRecurrent (layer :~@> n) (recin :~@+> nr) !x
  = let (recin', y) = runRecurrentForwards layer recin x
        (nr', o) = runRecurrent n nr y
    in (recin' :~@+> nr', o)
runRecurrent (OR layer) (ORS ()) !x
  = (ORS (), runForwards layer x)

runRecurrent _ _ _
  = error "Impossible for the gradients of a network to have a different length or shape to the network"
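A hedged sketch of how these two entry points are typically combined (the name `trainMany` is illustrative): fold `trainRecurrent` over a batch of example sequences, then use `runRecurrent` for inference.

```haskell
-- Train over many sequences; each sequence pairs an input with an
-- optional target (Nothing where no gradient should be applied).
trainMany :: SingI (Last shapes)
          => LearningParameters
          -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
          -> [[(S (Head shapes), Maybe (S (Last shapes)))]]
          -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
trainMany params =
  foldl (\(net, ins) series -> trainRecurrent params net ins series)
```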
92
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
Normal file
@ -0,0 +1,92 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module Grenade.Recurrent.Layers.BasicRecurrent (
    BasicRecurrent (..)
  , randomBasicRecurrent
  ) where

import Control.Monad.Random ( MonadRandom, getRandom )

import Data.Singletons.TypeLits

import Numeric.LinearAlgebra.Static

import GHC.TypeLits

import Grenade.Core.Network
import Grenade.Core.Shape
import Grenade.Recurrent.Core.Network

data BasicRecurrent :: Nat -- Input layer size
                    -> Nat -- Output layer size
                    -> * where
  BasicRecurrent :: ( KnownNat input
                    , KnownNat output
                    , KnownNat matrixCols
                    , matrixCols ~ (input + output))
                 => !(R output)            -- Bias neuron weights
                 -> !(R output)            -- Bias neuron momentum
                 -> !(L output matrixCols) -- Activation
                 -> !(L output matrixCols) -- Momentum
                 -> BasicRecurrent input output

data BasicRecurrent' :: Nat -- Input layer size
                     -> Nat -- Output layer size
                     -> * where
  BasicRecurrent' :: ( KnownNat input
                     , KnownNat output
                     , KnownNat matrixCols
                     , matrixCols ~ (input + output))
                  => !(R output)            -- Bias neuron gradients
                  -> !(L output matrixCols)
                  -> BasicRecurrent' input output

instance Show (BasicRecurrent i o) where
  show BasicRecurrent {} = "BasicRecurrent"

instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
  type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)

  runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
    let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
        newBias = oldBias + newBiasMomentum
        newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
        regulariser = konst (learningRegulariser * learningRate) * oldActivations
        newActivations = oldActivations + newMomentum - regulariser
    in BasicRecurrent newBias newBiasMomentum newActivations newMomentum

  createRandom = randomBasicRecurrent

instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentUpdateLayer (BasicRecurrent i o) where
  type RecurrentShape (BasicRecurrent i o) = 'D1 o

instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentLayer (BasicRecurrent i o) ('D1 i) ('D1 o) where
  -- Do a matrix vector multiplication and return the result.
  runRecurrentForwards (BasicRecurrent wB _ wN _) (S1D lastOutput) (S1D thisInput) =
    let thisOutput = S1D $ wB + wN #> (thisInput # lastOutput)
    in (thisOutput, thisOutput)

  -- Run a backpropagation step for a fully connected layer.
  runRecurrentBackwards (BasicRecurrent _ _ wN _) (S1D lastOutput) (S1D thisInput) (S1D dRec) (S1D dEdy) =
    let biasGradient = (dRec + dEdy)
        layerGrad = (dRec + dEdy) `outer` (thisInput # lastOutput)
        -- calculate derivatives for next step
        (backGrad, recGrad) = split $ tr wN #> (dRec + dEdy)
    in (BasicRecurrent' biasGradient layerGrad, S1D recGrad, S1D backGrad)

randomBasicRecurrent :: (MonadRandom m, KnownNat i, KnownNat o, KnownNat x, x ~ (i + o))
                     => m (BasicRecurrent i o)
randomBasicRecurrent = do
  seed1 <- getRandom
  seed2 <- getRandom
  let wB = randomVector seed1 Uniform * 2 - 1
      wN = uniformSample seed2 (-1) 1
      bm = konst 0
      mm = konst 0
  return $ BasicRecurrent wB bm wN mm
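In symbols, matching the forward pass above and writing [x ; y] for vector concatenation (the `#` operator), the cell applies a single fully connected map to the current input joined with its previous output, and passes the result both upward and sideways:

```latex
y_t = b + W \, [\, x_t \,;\, y_{t-1} \,]
```

Note there is no nonlinearity in this basic cell; any squashing has to come from a following activation layer.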
244
src/Grenade/Recurrent/Layers/LSTM.hs
Normal file
@ -0,0 +1,244 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ViewPatterns #-}
module Grenade.Recurrent.Layers.LSTM (
    LSTM (..)
  , LSTMWeights (..)
  , randomLSTM
  ) where

import Control.Monad.Random ( MonadRandom, getRandom )

-- import Data.List ( foldl1' )
import Data.Singletons.TypeLits

import Numeric.LinearAlgebra.Static

import Grenade.Core.Network
import Grenade.Core.Shape

import Grenade.Layers.Internal.Update

import Grenade.Recurrent.Core.Network

-- | Long Short Term Memory Recurrent unit
--
--   This is a Peephole formulation, so the recurrent shape is
--   just the cell state; the previous output is not held or used
--   at all.
data LSTM :: Nat -> Nat -> * where
  LSTM :: ( KnownNat input
          , KnownNat output
          ) => !(LSTMWeights input output) -- Weights
            -> !(LSTMWeights input output) -- Momentums
            -> LSTM input output

data LSTMWeights :: Nat -> Nat -> * where
  LSTMWeights :: ( KnownNat input
                 , KnownNat output
                 ) => {
                   lstmWf :: !(L output input)  -- Weight Forget     (W_f)
                 , lstmUf :: !(L output output) -- Cell State Forget (U_f)
                 , lstmBf :: !(R output)        -- Bias Forget       (b_f)
                 , lstmWi :: !(L output input)  -- Weight Input      (W_i)
                 , lstmUi :: !(L output output) -- Cell State Input  (U_i)
                 , lstmBi :: !(R output)        -- Bias Input        (b_i)
                 , lstmWo :: !(L output input)  -- Weight Output     (W_o)
                 , lstmUo :: !(L output output) -- Cell State Output (U_o)
                 , lstmBo :: !(R output)        -- Bias Output       (b_o)
                 , lstmWc :: !(L output input)  -- Weight Cell       (W_c)
                 , lstmBc :: !(R output)        -- Bias Cell         (b_c)
                 } -> LSTMWeights input output

instance Show (LSTM i o) where
  show LSTM {} = "LSTM"

instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
  -- The gradients are the same shape as the weights and momentum.
  -- This seems to be a general pattern; maybe it should be enforced.
  type Gradient (LSTM i o) = (LSTMWeights i o)

  -- Run the update function for each matrix/vector group of weights, momentums and gradients.
  -- Hmm, maybe the function should be used instead of passing in the learning parameters.
  runUpdate LearningParameters {..} (LSTM w m) g =
    let (wf, wf') = u lstmWf w m g
        (uf, uf') = u lstmUf w m g
        (bf, bf') = v lstmBf w m g
        (wi, wi') = u lstmWi w m g
        (ui, ui') = u lstmUi w m g
        (bi, bi') = v lstmBi w m g
        (wo, wo') = u lstmWo w m g
        (uo, uo') = u lstmUo w m g
        (bo, bo') = v lstmBo w m g
        (wc, wc') = u lstmWc w m g
        (bc, bc') = v lstmBc w m g
    in LSTM (LSTMWeights wf uf bf wi ui bi wo uo bo wc bc) (LSTMWeights wf' uf' bf' wi' ui' bi' wo' uo' bo' wc' bc')
    where
      -- Utility function for updating with the momentum, gradients, and weights.
      u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
      u e (e -> weights) (e -> momentum) (e -> gradient) =
        decendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum

      v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
      v e (e -> weights) (e -> momentum) (e -> gradient) =
        decendVector learningRate learningMomentum learningRegulariser weights gradient momentum

  -- There are a lot of updates here, so to try and minimise the number of data copies
  -- we'll create a mutable bucket for each.
  -- runUpdates rate lstm gs =
  --   let combinedGradient = foldl1' uu gs
  --   in runUpdate rate lstm combinedGradient
  --   where
  --     uu :: (KnownNat i, KnownNat o) => LSTMWeights i o -> LSTMWeights i o -> LSTMWeights i o
  --     uu a b =
  --       let wf = u lstmWf a b
  --           uf = u lstmUf a b
  --           bf = v lstmBf a b
  --           wi = u lstmWi a b
  --           ui = u lstmUi a b
  --           bi = v lstmBi a b
  --           wo = u lstmWo a b
  --           uo = u lstmUo a b
  --           bo = v lstmBo a b
  --           wc = u lstmWc a b
  --           bc = v lstmBc a b
  --       in LSTMWeights wf uf bf wi ui bi wo uo bo wc bc
  --     u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> L out ix
  --     u e (e -> a) (e -> b) = tr $ tr a + tr b

  --     v :: forall x ix. (x -> (R ix)) -> x -> x -> R ix
  --     v e (e -> a) (e -> b) = a + b

  createRandom = randomLSTM

instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
  -- The recurrent shape is the same size as the output.
  -- It's actually the cell state however, as this is a peephole variety LSTM.
  type RecurrentShape (LSTM i o) = 'D1 o

instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
  -- Forward propagation for the LSTM layer.
  -- The size of the cell state is also the size of the output.
  runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
    let -- Forget state vector
        f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
        -- Input state vector
        i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
        -- Output state vector
        o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
        -- Cell input state vector
        c_x = tanh $ lstmBc + lstmWc #> input
        -- Cell state
        c_t = f_t * cell + i_t * c_x
        -- Output (it's sometimes recommended to use tanh c_t)
        h_t = o_t * c_t
    in (S1D c_t, S1D h_t)

  -- Run a backpropagation step for an LSTM layer.
  -- We're doing all the derivatives by hand here, so one should
  -- be extra careful when changing this.
  --
  -- There's a test version using the AD library without hmatrix in the test
  -- suite. These should always match.
  runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) (S1D cellGrad) (S1D h_t') =
    -- We're not keeping the Wengert tape during the forward pass,
    -- so we're duplicating some work here.
    --
    -- If I was being generous, I'd call it checkpointing.
    --
    -- Maybe think about better ways to store some intermediate states.
    let -- Forget state vector
        f_s = lstmBf + lstmWf #> input + lstmUf #> cell
        f_t = sigmoid f_s
        -- Input state vector
        i_s = lstmBi + lstmWi #> input + lstmUi #> cell
        i_t = sigmoid i_s
        -- Output state vector
        o_s = lstmBo + lstmWo #> input + lstmUo #> cell
        o_t = sigmoid o_s
        -- Cell input state vector
        c_s = lstmBc + lstmWc #> input
        c_x = tanh c_s
        -- Cell state
        c_t = f_t * cell + i_t * c_x

        -- Reverse Mode AD Derivatives
        c_t' = h_t' * o_t + cellGrad

        f_t' = c_t' * cell
        f_s' = sigmoid' f_s * f_t'

        o_t' = h_t' * c_t
        o_s' = sigmoid' o_s * o_t'

        i_t' = c_t' * c_x
        i_s' = sigmoid' i_s * i_t'

        c_x' = c_t' * i_t
        c_s' = tanh' c_s * c_x'

        -- The derivatives to pass sideways (recurrent) and downwards
        cell' = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
        input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'

        -- Calculate the gradient matrices for the input
        lstmWf' = f_s' `outer` input
        lstmWi' = i_s' `outer` input
        lstmWo' = o_s' `outer` input
        lstmWc' = c_s' `outer` input

        -- Calculate the gradient matrices for the cell
        lstmUf' = f_s' `outer` cell
        lstmUi' = i_s' `outer` cell
        lstmUo' = o_s' `outer` cell

        -- The biases just get the values, but we'll write it so it's obvious
        lstmBf' = f_s'
        lstmBi' = i_s'
        lstmBo' = o_s'
        lstmBc' = c_s'

        gradients = LSTMWeights lstmWf' lstmUf' lstmBf' lstmWi' lstmUi' lstmBi' lstmWo' lstmUo' lstmBo' lstmWc' lstmBc'
    in (gradients, S1D cell', S1D input')

-- | Generate an LSTM layer with random weights.
--   One can also just call createRandom from UpdateLayer.
--
--   Has forget gate biases set to 1 to encourage early learning.
--
--   https://github.com/karpathy/char-rnn/commit/0dfeaa454e687dd0278f036552ea1e48a0a408c9
--
randomLSTM :: forall m i o. (MonadRandom m, KnownNat i, KnownNat o)
           => m (LSTM i o)
randomLSTM = do
  let w = (\s -> uniformSample s (-1) 1 ) <$> getRandom
      u = (\s -> uniformSample s (-1) 1 ) <$> getRandom
      v = (\s -> randomVector s Uniform * 2 - 1) <$> getRandom

      w0 = konst 0
      u0 = konst 0
      v0 = konst 0

  LSTM <$> (LSTMWeights <$> w <*> u <*> pure (konst 1) <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v)
       <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)

-- | Maths
--
--   TODO: move these elsewhere
sigmoid :: Floating a => a -> a
sigmoid x = 1 / (1 + exp (-x))

sigmoid' :: Floating a => a -> a
sigmoid' x = logix * (1 - logix)
  where
    logix = sigmoid x

tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
23
src/Grenade/Recurrent/Layers/Trivial.hs
Normal file
@ -0,0 +1,23 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
module Grenade.Recurrent.Layers.Trivial (
    Trivial (..)
  ) where

import Grenade.Core.Network

-- | A trivial layer.
data Trivial = Trivial
  deriving Show

instance UpdateLayer Trivial where
  type Gradient Trivial = ()
  runUpdate _ _ _ = Trivial
  createRandom = return Trivial

instance (a ~ b) => Layer Trivial a b where
  runForwards _ = id
  runBackwards _ _ y = ((), y)
84
src/Grenade/Utils/OneHot.hs
Normal file
@ -0,0 +1,84 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}

module Grenade.Utils.OneHot (
    oneHot
  , hotMap
  , makeHot
  , unHot
  ) where

import Data.List ( group, sort )

import Data.Map ( Map )
import qualified Data.Map as M

import Data.Proxy
import Data.Singletons.TypeLits

import Data.Vector ( Vector )
import qualified Data.Vector as V

import Numeric.LinearAlgebra ( maxIndex )
import Numeric.LinearAlgebra.Devel
import Numeric.LinearAlgebra.Static

import Grenade.Core.Shape

-- | From an int which is hot, create a 1D Shape
--   with one index hot (1) and the rest 0.
--   Returns Nothing if the hot number is larger
--   than the length of the vector.
oneHot :: forall n. (KnownNat n)
       => Int -> Maybe (S ('D1 n))
oneHot hot =
  let len = fromIntegral $ natVal (Proxy :: Proxy n)
  in if hot < len
     then
       fmap S1D . create $ runSTVector $ do
         vec <- newVector 0 len
         writeVector vec hot 1
         return vec
     else Nothing

-- | Create a one hot map from any enumerable.
--   Returns a map, and the ordered list for the reverse transformation.
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
hotMap n as =
  let len = fromIntegral $ natVal n
      uniq = [ c | (c:_) <- group $ sort as ]
      hotl = length uniq
  in if hotl <= len
     then
       Just (M.fromList $ zip uniq [0..], V.fromList uniq)
     else Nothing

-- | From a map and value, create a 1D Shape
--   with one index hot (1) and the rest 0.
--   Returns Nothing if the hot number is larger
--   than the length of the vector or the map
--   doesn't contain the value.
makeHot :: forall a n. (Ord a, KnownNat n)
        => Map a Int -> a -> Maybe (S ('D1 n))
makeHot m x = do
  hot <- M.lookup x m
  let len = fromIntegral $ natVal (Proxy :: Proxy n)
  if hot < len
    then
      fmap S1D . create $ runSTVector $ do
        vec <- newVector 0 len
        writeVector vec hot 1
        return vec
    else Nothing

unHot :: forall a n. (KnownNat n)
      => Vector a -> (S ('D1 n)) -> Maybe a
unHot v (S1D xs)
  = (V.!?) v
  $ maxIndex (extract xs)
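A hedged round-trip sketch over a tiny vocabulary (the values and the printed form of `S1D` are illustrative only):

```haskell
-- λ> let Just (m, ordered) = hotMap (Proxy :: Proxy 3) "abcab"
-- λ> makeHot m 'b' :: Maybe (S ('D1 3))
-- Just (S1D [0.0, 1.0, 0.0])     -- 'b' is index 1 of the sorted uniques "abc"
-- λ> makeHot m 'b' >>= unHot ordered
-- Just 'b'
```

`unHot` takes the argmax of the vector, so it also decodes soft network outputs, not just exact one-hot encodings.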
@ -1,11 +1,10 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Convolution where

@ -30,6 +29,17 @@ data OpaqueConvolution :: * where
instance Show OpaqueConvolution where
  show (OpaqueConvolution n) = show n

genConvolution :: ( KnownNat channels
                  , KnownNat filters
                  , KnownNat kernelRows
                  , KnownNat kernelColumns
                  , KnownNat strideRows
                  , KnownNat strideColumns
                  , KnownNat kernelFlattened
                  , kernelFlattened ~ (kernelRows * kernelColumns * channels)
                  ) => Jack (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
genConvolution = Convolution <$> uniformSample <*> uniformSample

genOpaqueOpaqueConvolution :: Jack OpaqueConvolution
genOpaqueOpaqueConvolution = do
  Just channels <- someNatVal <$> choose (1, 10)
@ -46,7 +56,7 @@ genOpaqueOpaqueConvolution = do
      p2 = natDict pkc
      p3 = natDict pch
  in case p1 %* p2 %* p3 of
    Dict -> OpaqueConvolution <$> (Convolution <$> uniformSample <*> uniformSample :: Jack (Convolution ch fl kr kc sr sc))
    Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc))

prop_conv_net_witness =
  gamble genOpaqueOpaqueConvolution $ \onet ->
@ -80,9 +90,9 @@ prop_conv_net =
      , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
      , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
    (Dict, Dict, Dict, Dict) ->
      gamble (S3D' <$> uniformSample) $ \(input :: S' ('D3 inRows inCols channels)) ->
        let output :: S' ('D3 outRows outCols filters) = runForwards convLayer input
            backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S' ('D3 inRows inCols channels))
      gamble (S3D <$> uniformSample) $ \(input :: S ('D3 inRows inCols channels)) ->
        let output :: S ('D3 outRows outCols filters) = runForwards convLayer input
            backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S ('D3 inRows inCols channels))
              = runBackwards convLayer input output
        in backed `seq` True
  ) :: Property
@ -44,10 +44,10 @@ genOpaqueFullyConnected = do
prop_fully_connected_forwards :: Property
prop_fully_connected_forwards =
  gamble genOpaqueFullyConnected $ \(OpaqueFullyConnected (fclayer :: FullyConnected i o)) ->
    gamble (S1D' <$> randomVector) $ \(input :: S' ('D1 i)) ->
      let output :: S' ('D1 o) = runForwards fclayer input
          backed :: (Gradient (FullyConnected i o), S' ('D1 i))
            = runBackwards fclayer input output
    gamble (S1D <$> randomVector) $ \(input :: S ('D1 i)) ->
      let output :: S ('D1 o) = runForwards fclayer input
          backed :: (Gradient (FullyConnected i o), S ('D1 i))
            = runBackwards fclayer input output
      in backed `seq` True

return []
@ -1,8 +1,8 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Pooling where

101
test/Test/Grenade/Recurrent/Layers/LSTM.hs
Normal file
101
test/Test/Grenade/Recurrent/Layers/LSTM.hs
Normal file
@ -0,0 +1,101 @@
|
||||
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Recurrent.Layers.LSTM where

import Disorder.Jack

import Data.Foldable ( toList )
import Data.Singletons.TypeLits

import Grenade
import Grenade.Recurrent

import qualified Numeric.LinearAlgebra as H
import qualified Numeric.LinearAlgebra.Static as S

import qualified Test.Grenade.Recurrent.Layers.LSTM.Reference as Reference
import Test.Jack.Hmatrix

genLSTM :: forall i o. (KnownNat i, KnownNat o) => Jack (LSTM i o)
genLSTM = do
  let w  = uniformSample
      u  = uniformSample
      v  = randomVector

      w0 = S.konst 0
      u0 = S.konst 0
      v0 = S.konst 0

  LSTM <$> (LSTMWeights <$> w <*> u <*> v <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v)
       <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
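A note on genLSTM: the first LSTMWeights is built from freshly generated weight matrices and bias vectors, while the second, which appears to be the layer's momentum store, is made entirely of S.konst 0, so generated layers always start with zero momentum.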

prop_lstm_reference_forwards =
  gamble randomVector $ \(input :: S.R 3) ->
    gamble randomVector $ \(cell :: S.R 2) ->
      gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
        let actual = runRecurrentForwards net (S1D cell) (S1D input)
        in  case actual of
              ((S1D cellOut) :: S ('D1 2), (S1D output) :: S ('D1 2)) ->
                let cellOut'  = Reference.Vector . H.toList . S.extract $ cellOut
                    output'   = Reference.Vector . H.toList . S.extract $ output
                    refNet    = Reference.lstmToReference lstmWeights
                    refCell   = Reference.Vector . H.toList . S.extract $ cell
                    refInput  = Reference.Vector . H.toList . S.extract $ input
                    (refCO, refO) = Reference.runLSTM refNet refCell refInput
                in  toList refCO ~~~ toList cellOut' .&&. toList refO ~~~ toList output'

prop_lstm_reference_backwards =
  gamble randomVector $ \(input :: S.R 3) ->
    gamble randomVector $ \(cell :: S.R 2) ->
      gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
        let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
        in  case actualBacks of
              (actualGradients, _, _) ->
                let refNet       = Reference.lstmToReference lstmWeights
                    refCell      = Reference.Vector . H.toList . S.extract $ cell
                    refInput     = Reference.Vector . H.toList . S.extract $ input
                    refGradients = Reference.runLSTMback refCell refInput refNet
                in  toList refGradients ~~~ toList (Reference.lstmToReference actualGradients)

prop_lstm_reference_backwards_input =
  gamble randomVector $ \(input :: S.R 3) ->
    gamble randomVector $ \(cell :: S.R 2) ->
      gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
        let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
        in  case actualBacks of
              (_, _, S1D actualGradients) ->
                let refNet       = Reference.lstmToReference lstmWeights
                    refCell      = Reference.Vector . H.toList . S.extract $ cell
                    refInput     = Reference.Vector . H.toList . S.extract $ input
                    refGradients = Reference.runLSTMbackOnInput refCell refNet refInput
                in  toList refGradients ~~~ H.toList (S.extract actualGradients)

prop_lstm_reference_backwards_cell =
  gamble randomVector $ \(input :: S.R 3) ->
    gamble randomVector $ \(cell :: S.R 2) ->
      gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
        let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
        in  case actualBacks of
              (_, S1D actualGradients, _) ->
                let refNet       = Reference.lstmToReference lstmWeights
                    refCell      = Reference.Vector . H.toList . S.extract $ cell
                    refInput     = Reference.Vector . H.toList . S.extract $ input
                    refGradients = Reference.runLSTMbackOnCell refInput refNet refCell
                in  toList refGradients ~~~ (H.toList . S.extract $ actualGradients)

-- | Approximate elementwise equality on lists of doubles. Note the
--   abs: comparing the raw differences would let large negative
--   deviations slip through.
(~~~) as bs = all ((< 1e-8) . abs) (zipWith (-) as bs)
infix 4 ~~~
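-- For example, with the absolute-difference check above:
--   [1.0, 2.0] ~~~ [1.0, 2.0 + 1e-12]   -- True
--   [1.0, 2.0] ~~~ [1.0, 2.1]           -- False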
return []
tests :: IO Bool
tests = $quickCheckAll
149
test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs
Normal file
@ -0,0 +1,149 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Recurrent.Layers.LSTM.Reference where

import Data.Reflection
import Numeric.AD.Mode.Reverse
import Numeric.AD.Internal.Reverse ( Tape )

import qualified Grenade.Recurrent.Layers.LSTM as LSTM
import qualified Numeric.LinearAlgebra.Static as S
import qualified Numeric.LinearAlgebra as H

--
-- This module contains list-only versions of an LSTM layer,
-- suitable for use with the AD library.
--
-- Using these, we can check that our fast back propagation
-- implementation is correct.
--

-- | List-only matrix, deriving Functor
data Matrix a = Matrix {
    matrixWeights :: [[a]]
  } deriving (Functor, Foldable, Traversable, Eq, Show)

-- | List-only vector, deriving Functor
data Vector a = Vector {
    vectorWeights :: [a]
  } deriving (Functor, Foldable, Traversable, Eq, Show)

-- | List-only LSTM weights
data RefLSTM a = RefLSTM
  { refLstmWf :: Matrix a -- Weight Forget     (W_f)
  , refLstmUf :: Matrix a -- Cell State Forget (U_f)
  , refLstmBf :: Vector a -- Bias Forget       (b_f)
  , refLstmWi :: Matrix a -- Weight Input      (W_i)
  , refLstmUi :: Matrix a -- Cell State Input  (U_i)
  , refLstmBi :: Vector a -- Bias Input        (b_i)
  , refLstmWo :: Matrix a -- Weight Output     (W_o)
  , refLstmUo :: Matrix a -- Cell State Output (U_o)
  , refLstmBo :: Vector a -- Bias Output       (b_o)
  , refLstmWc :: Matrix a -- Weight Cell       (W_c)
  , refLstmBc :: Vector a -- Bias Cell         (b_c)
  } deriving (Functor, Foldable, Traversable, Eq, Show)

lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double
lstmToReference LSTM.LSTMWeights {..} =
  let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget     (W_f)
      refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f)
      refLstmBf = Vector . H.toList  . S.extract $ lstmBf -- Bias Forget       (b_f)
      refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input      (W_i)
      refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input  (U_i)
      refLstmBi = Vector . H.toList  . S.extract $ lstmBi -- Bias Input        (b_i)
      refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output     (W_o)
      refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o)
      refLstmBo = Vector . H.toList  . S.extract $ lstmBo -- Bias Output       (b_o)
      refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell       (W_c)
      refLstmBc = Vector . H.toList  . S.extract $ lstmBc -- Bias Cell         (b_c)
  in  RefLSTM {..}
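-- Note: lstmToReference leans on RecordWildCards twice: the
-- LSTM.LSTMWeights {..} pattern brings every lstm* field into scope,
-- and RefLSTM {..} then picks up each refLstm* let binding by name.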

runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a)
runLSTM RefLSTM {..} cell input =
  let -- Forget state vector
      f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell
      -- Input state vector
      i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell
      -- Output state vector
      o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell
      -- Cell input state vector
      c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input
      -- Cell state
      c_t = f_t #* cell #+ i_t #* c_x
      -- Output (it's sometimes recommended to use tanh c_t)
      h_t = o_t #* c_t
  in  (c_t, h_t)
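-- For reference, the update equations implemented above, with #*
-- elementwise multiplication and sigma the logistic sigmoid:
--
--   f_t  = sigma (b_f + W_f x_t + U_f c_{t-1})
--   i_t  = sigma (b_i + W_i x_t + U_i c_{t-1})
--   o_t  = sigma (b_o + W_o x_t + U_o c_{t-1})
--   c'_t = tanh  (b_c + W_c x_t)          -- note: no U_c term in this variant
--   c_t  = f_t #* c_{t-1} #+ i_t #* c'_t
--   h_t  = o_t #* c_t                     -- note: no tanh on the output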

runLSTMback :: forall a. Floating a => Vector a -> Vector a -> RefLSTM a -> RefLSTM a
runLSTMback cell input =
  grad f
  where
    f :: forall s. Reifies s Tape => RefLSTM (Reverse s a) -> Reverse s a
    f net =
      let cell'  = fmap auto cell
          input' = fmap auto input
          (cells, forwarded) = runLSTM net cell' input'
      in  sum forwarded + sum cells

runLSTMbackOnInput :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a
runLSTMbackOnInput cell net =
  grad f
  where
    f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a
    f input =
      let cell' = fmap auto cell
          net'  = fmap auto net
          (cells, forwarded) = runLSTM net' cell' input
      in  sum forwarded + sum cells

runLSTMbackOnCell :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a
runLSTMbackOnCell input net =
  grad f
  where
    f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a
    f cell =
      let input' = fmap auto input
          net'   = fmap auto net
          (cells, forwarded) = runLSTM net' cell input'
      in  sum forwarded + sum cells

-- | Helper to multiply a matrix by a vector
matMult :: Num a => Matrix a -> Vector a -> Vector a
matMult (Matrix m) (Vector v) = Vector result
  where
    lrs = map length m
    l   = length v
    result = if all (== l) lrs
               then map (\r -> sum $ zipWith (*) r v) m
               else error $ "Matrix has rows of length " ++ show lrs ++
                            " but vector is of length " ++ show l

(#>) :: Num a => Matrix a -> Vector a -> Vector a
(#>) = matMult
infixr 8 #>

(#+) :: Num a => Vector a -> Vector a -> Vector a
(#+) (Vector as) (Vector bs) = Vector $ zipWith (+) as bs
infixl 6 #+

(#-) :: Num a => Vector a -> Vector a -> Vector a
(#-) (Vector as) (Vector bs) = Vector $ zipWith (-) as bs
infixl 6 #-

(#*) :: Num a => Vector a -> Vector a -> Vector a
(#*) (Vector as) (Vector bs) = Vector $ zipWith (*) as bs
infixl 7 #*

sigmoid :: (Functor f, Floating a) => f a -> f a
sigmoid xs = (\x -> 1 / (1 + exp (-x))) <$> xs
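As a minimal, self-contained sketch of the gradient-checking idea this module relies on, assuming only the `ad` package; `toy` and its hand-derived gradient are invented for illustration:

```haskell
import Numeric.AD (grad)

-- toy [x, y] = x^2 + 3xy, so the gradient is [2x + 3y, 3x].
toy :: Num a => [a] -> a
toy [x, y] = x * x + 3 * x * y
toy _      = error "toy: expects exactly two inputs"

main :: IO ()
main = do
  let xs     = [2.0, 1.0] :: [Double]
      adGrad = grad toy xs                -- [7.0, 6.0] via reverse-mode AD
      byHand = [2 * 2 + 3 * 1, 3 * 2]     -- the same values, derived by hand
  print (all ((< 1e-12) . abs) (zipWith (-) adGrad byHand))
```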
@ -4,7 +4,6 @@

module Test.Jack.Hmatrix where

import Data.Proxy
import Disorder.Jack

import GHC.TypeLits
@ -12,9 +11,7 @@ import GHC.TypeLits
import qualified Numeric.LinearAlgebra.Static as HStatic

randomVector :: forall n. KnownNat n => Jack (HStatic.R n)
randomVector = HStatic.fromList <$> vectorOf (fromInteger (natVal (Proxy :: Proxy n))) sizedRealFrac
randomVector = (\s -> HStatic.randomVector s HStatic.Uniform * 2 - 1) <$> sizedNat

uniformSample :: forall m n. (KnownNat m, KnownNat n) => Jack (HStatic.L m n)
uniformSample = HStatic.fromList
  <$> vectorOf (fromInteger (natVal (Proxy :: Proxy m)) * fromInteger (natVal (Proxy :: Proxy n)))
      sizedRealFrac
uniformSample = (\s -> HStatic.uniformSample s (-1) 1) <$> sizedNat
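This change swaps list-based construction for seed-based generation: each generator now draws a single random seed with sizedNat and lets hmatrix-static's randomVector / uniformSample fill the whole structure at once, with `* 2 - 1` rescaling the Uniform [0, 1] vector onto [-1, 1] to match uniformSample's (-1) 1 range.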

13
test/test.hs
@ -1,12 +1,13 @@
import Disorder.Core.Main

import qualified Test.Grenade.Layers.Pooling as Test.Grenade.Layers.Pooling
import qualified Test.Grenade.Layers.Convolution as Test.Grenade.Layers.Convolution
import qualified Test.Grenade.Layers.FullyConnected as Test.Grenade.Layers.FullyConnected
import qualified Test.Grenade.Layers.Pooling
import qualified Test.Grenade.Layers.Convolution
import qualified Test.Grenade.Layers.FullyConnected

import qualified Test.Grenade.Layers.Internal.Convolution as Test.Grenade.Layers.Internal.Convolution
import qualified Test.Grenade.Layers.Internal.Pooling as Test.Grenade.Layers.Internal.Pooling
import qualified Test.Grenade.Layers.Internal.Convolution
import qualified Test.Grenade.Layers.Internal.Pooling

import qualified Test.Grenade.Recurrent.Layers.LSTM

main :: IO ()
main =
@ -17,4 +18,6 @@ main =

  , Test.Grenade.Layers.Internal.Convolution.tests
  , Test.Grenade.Layers.Internal.Pooling.tests

  , Test.Grenade.Recurrent.Layers.LSTM.tests
  ]