Remove use of RecordWildCards

This commit is contained in:
Erik de Castro Lopo 2020-04-11 17:18:23 +10:00
parent f19da0f74d
commit 5206c95c42
8 changed files with 91 additions and 104 deletions

View File

@ -1,6 +1,5 @@
{-# LANGUAGE BangPatterns #-} {-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
@ -83,34 +82,34 @@ loadShakespeare path = do
return (V.fromList hot, m, cs) return (V.fromList hot, m, cs)
trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian) trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
trainSlice !rate !net !recIns input offset size = trainSlice !lrate !net !recIns input offset size =
let e = fmap (x . oneHot) . V.toList $ V.slice offset size input let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
in case reverse e of in case reverse e of
(o : l : xs) -> (o : l : xs) ->
let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs) let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
in trainRecurrent rate net recIns examples in trainRecurrent lrate net recIns examples
_ -> error "Not enough input" _ -> error "Not enough input"
where where
x = fromMaybe (error "Hot variable didn't fit.") x = fromMaybe (error "Hot variable didn't fit.")
runShakespeare :: ShakespeareOpts -> ExceptT String IO () runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
runShakespeare ShakespeareOpts {..} = do runShakespeare opts = do
(shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare $ trainingFile opts
(net0, i0) <- lift $ (net0, i0) <- lift $
case loadPath of case loadPath opts of
Just loadFile -> netLoad loadFile Just loadFile -> netLoad loadFile
Nothing -> (,0) <$> randomNet Nothing -> (,0) <$> randomNet
(trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do (trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do
xs <- take (iterations `div` 10) <$> getRandomRs (0, length shakespeare - size - 1) xs <- take (iterations opts `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice (rate opts) n i shakespeare offset size) (net, io) xs
results <- take 1000 <$> generateParagraph trained bestInput temperature oneHotMap oneHotDictionary ( S1D $ konst 0) results <- take 1000 <$> generateParagraph trained bestInput (temperature opts) oneHotMap oneHotDictionary ( S1D $ konst 0)
putStrLn ("TRAINING STEP WITH SIZE: " ++ show size) putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
putStrLn (unAnnotateCapitals results) putStrLn (unAnnotateCapitals results)
return (trained, bestInput) return (trained, bestInput)
) (net0, i0) $ replicate 10 sequenceSize ) (net0, i0) $ replicate 10 (sequenceSize opts)
case savePath of case savePath opts of
Just saveFile -> lift . B.writeFile saveFile $ runPut (put trained >> put bestInput) Just saveFile -> lift . B.writeFile saveFile $ runPut (put trained >> put bestInput)
Nothing -> return () Nothing -> return ()
@ -122,12 +121,12 @@ generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes
-> Vector a -> Vector a
-> S ('D1 n) -> S ('D1 n)
-> IO [a] -> IO [a]
generateParagraph n s temperature hotmap hotdict = generateParagraph n s temp hotmap hotdict =
go s go s
where where
go x y = go x y =
do let (_, ns, o) = runRecurrent n x y do let (_, ns, o) = runRecurrent n x y
un <- sample temperature hotdict o un <- sample temp hotdict o
Just re <- return $ makeHot hotmap un Just re <- return $ makeHot hotmap un
rest <- unsafeInterleaveIO $ go ns re rest <- unsafeInterleaveIO $ go ns re
return (un : rest) return (un : rest)

View File

@ -1,7 +1,6 @@
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
@ -139,8 +138,8 @@ instance ( KnownNat channels
, KnownNat (kernelRows * kernelColumns * channels) , KnownNat (kernelRows * kernelColumns * channels)
) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where ) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
type Gradient (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Convolution' channels filters kernelRows kernelColumns strideRows strideColumns) type Gradient (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Convolution' channels filters kernelRows kernelColumns strideRows strideColumns)
runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) = runUpdate lp (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
in Convolution newKernel newMomentum in Convolution newKernel newMomentum
createRandom = randomConvolution createRandom = randomConvolution

View File

@ -1,7 +1,6 @@
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
@ -138,8 +137,8 @@ instance ( KnownNat channels
, KnownNat (kernelRows * kernelColumns * filters) , KnownNat (kernelRows * kernelColumns * filters)
) => UpdateLayer (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) where ) => UpdateLayer (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) where
type Gradient (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Deconvolution' channels filters kernelRows kernelColumns strideRows strideColumns) type Gradient (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Deconvolution' channels filters kernelRows kernelColumns strideRows strideColumns)
runUpdate LearningParameters {..} (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) = runUpdate lp (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
in Deconvolution newKernel newMomentum in Deconvolution newKernel newMomentum
createRandom = randomDeconvolution createRandom = randomDeconvolution

View File

@ -1,5 +1,4 @@
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
@ -38,9 +37,9 @@ instance Show (FullyConnected i o) where
instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
type Gradient (FullyConnected i o) = (FullyConnected' i o) type Gradient (FullyConnected i o) = (FullyConnected' i o)
runUpdate LearningParameters {..} (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) = runUpdate lp (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
let (newBias, newBiasMomentum) = descendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum let (newBias, newBiasMomentum) = descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldBias biasGradient oldBiasMomentum
(newActivations, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum (newActivations, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldActivations activationGradient oldMomentum
in FullyConnected (FullyConnected' newBias newActivations) (FullyConnected' newBiasMomentum newMomentum) in FullyConnected (FullyConnected' newBias newActivations) (FullyConnected' newBiasMomentum newMomentum)
createRandom = randomFullyConnected createRandom = randomFullyConnected

View File

@ -6,7 +6,6 @@
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Grenade.Recurrent.Core.Runner ( module Grenade.Recurrent.Core.Runner (
runRecurrentExamples runRecurrentExamples
@ -101,11 +100,11 @@ updateRecInputs :: Fractional (RecurrentInputs sublayers)
-> RecurrentInputs sublayers -> RecurrentInputs sublayers
-> RecurrentInputs sublayers -> RecurrentInputs sublayers
updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys) updateRecInputs lp (() :~~+> xs) (() :~~+> ys)
= () :~~+> updateRecInputs l xs ys = () :~~+> updateRecInputs lp xs ys
updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys) updateRecInputs lp (x :~@+> xs) (y :~@+> ys)
= (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys = (realToFrac (1 - learningRate lp * learningRegulariser lp) * x - realToFrac (learningRate lp) * y) :~@+> updateRecInputs lp xs ys
updateRecInputs _ RINil RINil updateRecInputs _ RINil RINil
= RINil = RINil

View File

@ -1,7 +1,6 @@
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
@ -60,11 +59,11 @@ instance Show (BasicRecurrent i o) where
instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o) type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)
runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) = runUpdate lp (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient let newBiasMomentum = konst (learningMomentum lp) * oldBiasMomentum - konst (learningRate lp) * biasGradient
newBias = oldBias + newBiasMomentum newBias = oldBias + newBiasMomentum
newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient newMomentum = konst (learningMomentum lp) * oldMomentum - konst (learningRate lp) * activationGradient
regulariser = konst (learningRegulariser * learningRate) * oldActivations regulariser = konst (learningRegulariser lp * learningRate lp) * oldActivations
newActivations = oldActivations + newMomentum - regulariser newActivations = oldActivations + newMomentum - regulariser
in BasicRecurrent newBias newBiasMomentum newActivations newMomentum in BasicRecurrent newBias newBiasMomentum newActivations newMomentum

View File

@ -3,7 +3,6 @@
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE MultiParamTypeClasses #-}
@ -75,7 +74,7 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
-- Run the update function for each group matrix/vector of weights, momentums and gradients. -- Run the update function for each group matrix/vector of weights, momentums and gradients.
-- Hmm, maybe the function should be used instead of passing in the learning parameters. -- Hmm, maybe the function should be used instead of passing in the learning parameters.
runUpdate LearningParameters {..} (LSTM w m) g = runUpdate lp (LSTM w m) g =
let (wf, wf') = u lstmWf w m g let (wf, wf') = u lstmWf w m g
(uf, uf') = u lstmUf w m g (uf, uf') = u lstmUf w m g
(bf, bf') = v lstmBf w m g (bf, bf') = v lstmBf w m g
@ -92,11 +91,11 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
-- Utility function for updating with the momentum, gradients, and weights. -- Utility function for updating with the momentum, gradients, and weights.
u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix)) u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
u e (e -> weights) (e -> momentum) (e -> gradient) = u e (e -> weights) (e -> momentum) (e -> gradient) =
descendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix)) v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
v e (e -> weights) (e -> momentum) (e -> gradient) = v e (e -> weights) (e -> momentum) (e -> gradient) =
descendVector learningRate learningMomentum learningRegulariser weights gradient momentum descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
-- There's a lot of updates here, so to try and minimise the number of data copies -- There's a lot of updates here, so to try and minimise the number of data copies
-- we'll create a mutable bucket for each. -- we'll create a mutable bucket for each.
@ -137,18 +136,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o) type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
-- Forward propagation for the LSTM layer. -- Forward propagation for the LSTM layer.
-- The size of the cell state is also the size of the output. -- The size of the cell state is also the size of the output.
runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) = runRecurrentForwards (LSTM lw _) (S1D cell) (S1D input) =
let -- Forget state vector let -- Forget state vector
f_s = lstmBf + lstmWf #> input + lstmUf #> cell f_s = lstmBf lw + lstmWf lw #> input + lstmUf lw #> cell
f_t = sigmoid f_s f_t = sigmoid f_s
-- Input state vector -- Input state vector
i_s = lstmBi + lstmWi #> input + lstmUi #> cell i_s = lstmBi lw + lstmWi lw #> input + lstmUi lw #> cell
i_t = sigmoid i_s i_t = sigmoid i_s
-- Output state vector -- Output state vector
o_s = lstmBo + lstmWo #> input + lstmUo #> cell o_s = lstmBo lw + lstmWo lw #> input + lstmUo lw #> cell
o_t = sigmoid o_s o_t = sigmoid o_s
-- Cell input state vector -- Cell input state vector
c_s = lstmBc + lstmWc #> input c_s = lstmBc lw + lstmWc lw #> input
c_x = tanh c_s c_x = tanh c_s
-- Cell state -- Cell state
c_t = f_t * cell + i_t * c_x c_t = f_t * cell + i_t * c_x
@ -162,7 +161,7 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
-- --
-- There's a test version using the AD library without hmatrix in the test -- There's a test version using the AD library without hmatrix in the test
-- suite. These should match always. -- suite. These should match always.
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') = runRecurrentBackwards (LSTM lw _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
let -- Reverse Mode AD Derivitives let -- Reverse Mode AD Derivitives
c_t' = h_t' * o_t + cellGrad c_t' = h_t' * o_t + cellGrad
@ -179,8 +178,8 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
c_s' = tanh' c_s * c_x' c_s' = tanh' c_s * c_x'
-- The derivatives to pass sideways (recurrent) and downwards -- The derivatives to pass sideways (recurrent) and downwards
cell' = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t cell' = tr (lstmUf lw) #> f_s' + tr (lstmUo lw) #> o_s' + tr (lstmUi lw) #> i_s' + c_t' * f_t
input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s' input' = tr (lstmWf lw) #> f_s' + tr (lstmWo lw) #> o_s' + tr (lstmWi lw) #> i_s' + tr (lstmWc lw) #> c_s'
-- Calculate the gradient Matricies for the input -- Calculate the gradient Matricies for the input
lstmWf' = f_s' `outer` input lstmWf' = f_s' `outer` input
@ -239,18 +238,18 @@ tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
put (LSTM LSTMWeights {..} _) = do put (LSTM lw _) = do
u lstmWf u (lstmWf lw)
u lstmUf u (lstmUf lw)
v lstmBf v (lstmBf lw)
u lstmWi u (lstmWi lw)
u lstmUi u (lstmUi lw)
v lstmBi v (lstmBi lw)
u lstmWo u (lstmWo lw)
u lstmUo u (lstmUo lw)
v lstmBo v (lstmBo lw)
u lstmWc u (lstmWc lw)
v lstmBc v (lstmBc lw)
where where
u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a) u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
u = putListOf put . LA.toList . LA.flatten . extract u = putListOf put . LA.toList . LA.flatten . extract
@ -258,18 +257,8 @@ instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
v = putListOf put . LA.toList . extract v = putListOf put . LA.toList . extract
get = do get = do
lstmWf <- u w <- LSTMWeights <$> u <*> u <*> v <*> u <*> u <*> v <*> u <*> u <*> v <*> u <*> v
lstmUf <- u return $ LSTM w (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
lstmBf <- v
lstmWi <- u
lstmUi <- u
lstmBi <- v
lstmWo <- u
lstmUo <- u
lstmBo <- v
lstmWc <- u
lstmBc <- v
return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
where where
u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a) u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
u = let f = fromIntegral $ natVal (Proxy :: Proxy a) u = let f = fromIntegral $ natVal (Proxy :: Proxy a)

View File

@ -6,7 +6,6 @@
{-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-} {-# LANGUAGE RankNTypes #-}
@ -15,9 +14,13 @@ module Test.Grenade.Recurrent.Layers.LSTM.Reference where
import Data.Reflection import Data.Reflection
import Numeric.AD.Mode.Reverse import Numeric.AD.Mode.Reverse
import Numeric.AD.Internal.Reverse ( Tape ) import Numeric.AD.Internal.Reverse (Tape)
import GHC.TypeLits (KnownNat)
import Grenade.Recurrent.Layers.LSTM (LSTMWeights (..))
import qualified Grenade.Recurrent.Layers.LSTM as LSTM import qualified Grenade.Recurrent.Layers.LSTM as LSTM
import qualified Numeric.LinearAlgebra.Static as S import qualified Numeric.LinearAlgebra.Static as S
import qualified Numeric.LinearAlgebra as H import qualified Numeric.LinearAlgebra as H
@ -54,31 +57,32 @@ data RefLSTM a = RefLSTM
, refLstmBc :: Vector a -- Bias Cell (b_c) , refLstmBc :: Vector a -- Bias Cell (b_c)
} deriving (Functor, Foldable, Traversable, Eq, Show) } deriving (Functor, Foldable, Traversable, Eq, Show)
lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double lstmToReference :: (KnownNat a, KnownNat b) => LSTM.LSTMWeights a b -> RefLSTM Double
lstmToReference LSTM.LSTMWeights {..} = lstmToReference lw =
let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget (W_f) RefLSTM
refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f) { refLstmWf = Matrix . H.toLists . S.extract $ lstmWf lw -- Weight Forget (W_f)
refLstmBf = Vector . H.toList . S.extract $ lstmBf -- Bias Forget (b_f) , refLstmUf = Matrix . H.toLists . S.extract $ lstmUf lw -- Cell State Forget (U_f)
refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input (W_i) , refLstmBf = Vector . H.toList . S.extract $ lstmBf lw -- Bias Forget (b_f)
refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input (U_i) , refLstmWi = Matrix . H.toLists . S.extract $ lstmWi lw -- Weight Input (W_i)
refLstmBi = Vector . H.toList . S.extract $ lstmBi -- Bias Input (b_i) , refLstmUi = Matrix . H.toLists . S.extract $ lstmUi lw -- Cell State Input (U_i)
refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output (W_o) , refLstmBi = Vector . H.toList . S.extract $ lstmBi lw -- Bias Input (b_i)
refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o) , refLstmWo = Matrix . H.toLists . S.extract $ lstmWo lw -- Weight Output (W_o)
refLstmBo = Vector . H.toList . S.extract $ lstmBo -- Bias Output (b_o) , refLstmUo = Matrix . H.toLists . S.extract $ lstmUo lw -- Cell State Output (U_o)
refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell (W_c) , refLstmBo = Vector . H.toList . S.extract $ lstmBo lw -- Bias Output (b_o)
refLstmBc = Vector . H.toList . S.extract $ lstmBc -- Bias Cell (b_c) , refLstmWc = Matrix . H.toLists . S.extract $ lstmWc lw -- Weight Cell (W_c)
in RefLSTM {..} , refLstmBc = Vector . H.toList . S.extract $ lstmBc lw -- Bias Cell (b_c)
}
runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a) runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a)
runLSTM RefLSTM {..} cell input = runLSTM rl cell input =
let -- Forget state vector let -- Forget state vector
f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell f_t = sigmoid $ refLstmBf rl #+ refLstmWf rl #> input #+ refLstmUf rl #> cell
-- Input state vector -- Input state vector
i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell i_t = sigmoid $ refLstmBi rl #+ refLstmWi rl #> input #+ refLstmUi rl #> cell
-- Output state vector -- Output state vector
o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell o_t = sigmoid $ refLstmBo rl #+ refLstmWo rl #> input #+ refLstmUo rl #> cell
-- Cell input state vector -- Cell input state vector
c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input c_x = fmap tanh $ refLstmBc rl #+ refLstmWc rl #> input
-- Cell state -- Cell state
c_t = f_t #* cell #+ i_t #* c_x c_t = f_t #* cell #+ i_t #* c_x
-- Output (it's sometimes recommended to use tanh c_t) -- Output (it's sometimes recommended to use tanh c_t)