From 5206c95c423d9755e620f41576470a281ba59c89 Mon Sep 17 00:00:00 2001 From: Erik de Castro Lopo Date: Sat, 11 Apr 2020 17:18:23 +1000 Subject: [PATCH] Remove use of RecordWildCards --- examples/main/shakespeare.hs | 25 +++--- src/Grenade/Layers/Convolution.hs | 5 +- src/Grenade/Layers/Deconvolution.hs | 5 +- src/Grenade/Layers/FullyConnected.hs | 7 +- src/Grenade/Recurrent/Core/Runner.hs | 9 +- .../Recurrent/Layers/BasicRecurrent.hs | 9 +- src/Grenade/Recurrent/Layers/LSTM.hs | 89 ++++++++----------- .../Recurrent/Layers/LSTM/Reference.hs | 46 +++++----- 8 files changed, 91 insertions(+), 104 deletions(-) diff --git a/examples/main/shakespeare.hs b/examples/main/shakespeare.hs index 7c87d80..bd9ff97 100644 --- a/examples/main/shakespeare.hs +++ b/examples/main/shakespeare.hs @@ -1,6 +1,5 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} @@ -83,34 +82,34 @@ loadShakespeare path = do return (V.fromList hot, m, cs) trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian) -trainSlice !rate !net !recIns input offset size = +trainSlice !lrate !net !recIns input offset size = let e = fmap (x . oneHot) . V.toList $ V.slice offset size input in case reverse e of (o : l : xs) -> let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs) - in trainRecurrent rate net recIns examples + in trainRecurrent lrate net recIns examples _ -> error "Not enough input" where x = fromMaybe (error "Hot variable didn't fit.") runShakespeare :: ShakespeareOpts -> ExceptT String IO () -runShakespeare ShakespeareOpts {..} = do - (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile +runShakespeare opts = do + (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare $ trainingFile opts (net0, i0) <- lift $ - case loadPath of + case loadPath opts of Just loadFile -> netLoad loadFile Nothing -> (,0) <$> randomNet (trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do - xs <- take (iterations `div` 10) <$> getRandomRs (0, length shakespeare - size - 1) - let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs - results <- take 1000 <$> generateParagraph trained bestInput temperature oneHotMap oneHotDictionary ( S1D $ konst 0) + xs <- take (iterations opts `div` 10) <$> getRandomRs (0, length shakespeare - size - 1) + let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice (rate opts) n i shakespeare offset size) (net, io) xs + results <- take 1000 <$> generateParagraph trained bestInput (temperature opts) oneHotMap oneHotDictionary ( S1D $ konst 0) putStrLn ("TRAINING STEP WITH SIZE: " ++ show size) putStrLn (unAnnotateCapitals results) return (trained, bestInput) - ) (net0, i0) $ replicate 10 sequenceSize + ) (net0, i0) $ replicate 10 (sequenceSize opts) - case savePath of + case savePath opts of Just saveFile -> lift . B.writeFile saveFile $ runPut (put trained >> put bestInput) Nothing -> return () @@ -122,12 +121,12 @@ generateParagraph :: forall layers shapes n a. 
(Last shapes ~ 'D1 n, Head shapes -> Vector a -> S ('D1 n) -> IO [a] -generateParagraph n s temperature hotmap hotdict = +generateParagraph n s temp hotmap hotdict = go s where go x y = do let (_, ns, o) = runRecurrent n x y - un <- sample temperature hotdict o + un <- sample temp hotdict o Just re <- return $ makeHot hotmap un rest <- unsafeInterleaveIO $ go ns re return (un : rest) diff --git a/src/Grenade/Layers/Convolution.hs b/src/Grenade/Layers/Convolution.hs index 11dc9c1..aa0048b 100644 --- a/src/Grenade/Layers/Convolution.hs +++ b/src/Grenade/Layers/Convolution.hs @@ -1,7 +1,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} @@ -139,8 +138,8 @@ instance ( KnownNat channels , KnownNat (kernelRows * kernelColumns * channels) ) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where type Gradient (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Convolution' channels filters kernelRows kernelColumns strideRows strideColumns) - runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) = - let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum + runUpdate lp (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) = + let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum in Convolution newKernel newMomentum createRandom = randomConvolution diff --git a/src/Grenade/Layers/Deconvolution.hs b/src/Grenade/Layers/Deconvolution.hs index 8e3926b..f46d12b 100644 --- a/src/Grenade/Layers/Deconvolution.hs +++ b/src/Grenade/Layers/Deconvolution.hs @@ -1,7 +1,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} @@ -138,8 +137,8 @@ instance ( KnownNat channels , KnownNat (kernelRows * kernelColumns * filters) ) => UpdateLayer (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) where type Gradient (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Deconvolution' channels filters kernelRows kernelColumns strideRows strideColumns) - runUpdate LearningParameters {..} (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) = - let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum + runUpdate lp (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) = + let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum in Deconvolution newKernel newMomentum createRandom = randomDeconvolution diff --git a/src/Grenade/Layers/FullyConnected.hs b/src/Grenade/Layers/FullyConnected.hs index ecaad46..7041658 100644 --- a/src/Grenade/Layers/FullyConnected.hs +++ b/src/Grenade/Layers/FullyConnected.hs @@ -1,5 +1,4 @@ {-# LANGUAGE DataKinds #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -38,9 +37,9 @@ instance Show (FullyConnected i o) where instance (KnownNat i, KnownNat 
o) => UpdateLayer (FullyConnected i o) where type Gradient (FullyConnected i o) = (FullyConnected' i o) - runUpdate LearningParameters {..} (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) = - let (newBias, newBiasMomentum) = descendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum - (newActivations, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum + runUpdate lp (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) = + let (newBias, newBiasMomentum) = descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldBias biasGradient oldBiasMomentum + (newActivations, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldActivations activationGradient oldMomentum in FullyConnected (FullyConnected' newBias newActivations) (FullyConnected' newBiasMomentum newMomentum) createRandom = randomFullyConnected diff --git a/src/Grenade/Recurrent/Core/Runner.hs b/src/Grenade/Recurrent/Core/Runner.hs index 8dacadf..68cb585 100644 --- a/src/Grenade/Recurrent/Core/Runner.hs +++ b/src/Grenade/Recurrent/Core/Runner.hs @@ -6,7 +6,6 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} module Grenade.Recurrent.Core.Runner ( runRecurrentExamples @@ -101,11 +100,11 @@ updateRecInputs :: Fractional (RecurrentInputs sublayers) -> RecurrentInputs sublayers -> RecurrentInputs sublayers -updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys) - = () :~~+> updateRecInputs l xs ys +updateRecInputs lp (() :~~+> xs) (() :~~+> ys) + = () :~~+> updateRecInputs lp xs ys -updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys) - = (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys +updateRecInputs lp (x :~@+> xs) (y :~@+> ys) + = (realToFrac (1 - learningRate lp * learningRegulariser lp) * x - realToFrac (learningRate lp) * y) :~@+> updateRecInputs lp xs ys updateRecInputs _ RINil RINil = RINil diff --git a/src/Grenade/Recurrent/Layers/BasicRecurrent.hs b/src/Grenade/Recurrent/Layers/BasicRecurrent.hs index 5dfa8e7..53b0133 100644 --- a/src/Grenade/Recurrent/Layers/BasicRecurrent.hs +++ b/src/Grenade/Recurrent/Layers/BasicRecurrent.hs @@ -1,7 +1,6 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -60,11 +59,11 @@ instance Show (BasicRecurrent i o) where instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o) - runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) = - let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient + runUpdate lp (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) = + let newBiasMomentum = konst (learningMomentum lp) * oldBiasMomentum - konst (learningRate lp) * biasGradient newBias = oldBias + 
newBiasMomentum - newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient - regulariser = konst (learningRegulariser * learningRate) * oldActivations + newMomentum = konst (learningMomentum lp) * oldMomentum - konst (learningRate lp) * activationGradient + regulariser = konst (learningRegulariser lp * learningRate lp) * oldActivations newActivations = oldActivations + newMomentum - regulariser in BasicRecurrent newBias newBiasMomentum newActivations newMomentum diff --git a/src/Grenade/Recurrent/Layers/LSTM.hs b/src/Grenade/Recurrent/Layers/LSTM.hs index f0ac79a..2437148 100644 --- a/src/Grenade/Recurrent/Layers/LSTM.hs +++ b/src/Grenade/Recurrent/Layers/LSTM.hs @@ -3,7 +3,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE MultiParamTypeClasses #-} @@ -75,7 +74,7 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where -- Run the update function for each group matrix/vector of weights, momentums and gradients. -- Hmm, maybe the function should be used instead of passing in the learning parameters. - runUpdate LearningParameters {..} (LSTM w m) g = + runUpdate lp (LSTM w m) g = let (wf, wf') = u lstmWf w m g (uf, uf') = u lstmUf w m g (bf, bf') = v lstmBf w m g @@ -92,11 +91,11 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where -- Utility function for updating with the momentum, gradients, and weights. u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix)) u e (e -> weights) (e -> momentum) (e -> gradient) = - descendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum + descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix)) v e (e -> weights) (e -> momentum) (e -> gradient) = - descendVector learningRate learningMomentum learningRegulariser weights gradient momentum + descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum -- There's a lot of updates here, so to try and minimise the number of data copies -- we'll create a mutable bucket for each. @@ -137,18 +136,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o) -- Forward propagation for the LSTM layer. -- The size of the cell state is also the size of the output. 
- runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) = + runRecurrentForwards (LSTM lw _) (S1D cell) (S1D input) = let -- Forget state vector - f_s = lstmBf + lstmWf #> input + lstmUf #> cell + f_s = lstmBf lw + lstmWf lw #> input + lstmUf lw #> cell f_t = sigmoid f_s -- Input state vector - i_s = lstmBi + lstmWi #> input + lstmUi #> cell + i_s = lstmBi lw + lstmWi lw #> input + lstmUi lw #> cell i_t = sigmoid i_s -- Output state vector - o_s = lstmBo + lstmWo #> input + lstmUo #> cell + o_s = lstmBo lw + lstmWo lw #> input + lstmUo lw #> cell o_t = sigmoid o_s -- Cell input state vector - c_s = lstmBc + lstmWc #> input + c_s = lstmBc lw + lstmWc lw #> input c_x = tanh c_s -- Cell state c_t = f_t * cell + i_t * c_x @@ -162,7 +161,7 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w -- -- There's a test version using the AD library without hmatrix in the test -- suite. These should match always. - runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') = + runRecurrentBackwards (LSTM lw _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') = let -- Reverse Mode AD Derivitives c_t' = h_t' * o_t + cellGrad @@ -179,8 +178,8 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w c_s' = tanh' c_s * c_x' -- The derivatives to pass sideways (recurrent) and downwards - cell' = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t - input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s' + cell' = tr (lstmUf lw) #> f_s' + tr (lstmUo lw) #> o_s' + tr (lstmUi lw) #> i_s' + c_t' * f_t + input' = tr (lstmWf lw) #> f_s' + tr (lstmWo lw) #> o_s' + tr (lstmWi lw) #> i_s' + tr (lstmWc lw) #> c_s' -- Calculate the gradient Matricies for the input lstmWf' = f_s' `outer` input @@ -239,44 +238,34 @@ tanh' :: (Floating a) => a -> a tanh' t = 1 - s ^ (2 :: Int) where s = tanh t instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where - put (LSTM LSTMWeights {..} _) = do - u lstmWf - u lstmUf - v lstmBf - u lstmWi - u lstmUi - v lstmBi - u lstmWo - u lstmUo - v lstmBo - u lstmWc - v lstmBc - where - u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a) - u = putListOf put . LA.toList . LA.flatten . extract - v :: forall a. (KnownNat a) => Putter (R a) - v = putListOf put . LA.toList . extract + put (LSTM lw _) = do + u (lstmWf lw) + u (lstmUf lw) + v (lstmBf lw) + u (lstmWi lw) + u (lstmUi lw) + v (lstmBi lw) + u (lstmWo lw) + u (lstmUo lw) + v (lstmBo lw) + u (lstmWc lw) + v (lstmBc lw) + where + u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a) + u = putListOf put . LA.toList . LA.flatten . extract + v :: forall a. (KnownNat a) => Putter (R a) + v = putListOf put . LA.toList . extract get = do - lstmWf <- u - lstmUf <- u - lstmBf <- v - lstmWi <- u - lstmUi <- u - lstmBi <- v - lstmWo <- u - lstmUo <- u - lstmBo <- v - lstmWc <- u - lstmBc <- v - return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0) - where - u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a) - u = let f = fromIntegral $ natVal (Proxy :: Proxy a) - in maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get - v :: forall a. (KnownNat a) => Get (R a) - v = maybe (fail "Vector of incorrect size") return . create . 
LA.fromList =<< getListOf get + w <- LSTMWeights <$> u <*> u <*> v <*> u <*> u <*> v <*> u <*> u <*> v <*> u <*> v + return $ LSTM w (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0) + where + u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a) + u = let f = fromIntegral $ natVal (Proxy :: Proxy a) + in maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get + v :: forall a. (KnownNat a) => Get (R a) + v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get - w0 = konst 0 - u0 = konst 0 - v0 = konst 0 + w0 = konst 0 + u0 = konst 0 + v0 = konst 0 diff --git a/test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs b/test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs index f9373da..ca04eb0 100644 --- a/test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs +++ b/test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs @@ -6,7 +6,6 @@ {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} @@ -15,9 +14,13 @@ module Test.Grenade.Recurrent.Layers.LSTM.Reference where import Data.Reflection import Numeric.AD.Mode.Reverse -import Numeric.AD.Internal.Reverse ( Tape ) +import Numeric.AD.Internal.Reverse (Tape) +import GHC.TypeLits (KnownNat) + +import Grenade.Recurrent.Layers.LSTM (LSTMWeights (..)) import qualified Grenade.Recurrent.Layers.LSTM as LSTM + import qualified Numeric.LinearAlgebra.Static as S import qualified Numeric.LinearAlgebra as H @@ -54,31 +57,32 @@ data RefLSTM a = RefLSTM , refLstmBc :: Vector a -- Bias Cell (b_c) } deriving (Functor, Foldable, Traversable, Eq, Show) -lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double -lstmToReference LSTM.LSTMWeights {..} = - let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget (W_f) - refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f) - refLstmBf = Vector . H.toList . S.extract $ lstmBf -- Bias Forget (b_f) - refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input (W_i) - refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input (U_i) - refLstmBi = Vector . H.toList . S.extract $ lstmBi -- Bias Input (b_i) - refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output (W_o) - refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o) - refLstmBo = Vector . H.toList . S.extract $ lstmBo -- Bias Output (b_o) - refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell (W_c) - refLstmBc = Vector . H.toList . S.extract $ lstmBc -- Bias Cell (b_c) - in RefLSTM {..} +lstmToReference :: (KnownNat a, KnownNat b) => LSTM.LSTMWeights a b -> RefLSTM Double +lstmToReference lw = + RefLSTM + { refLstmWf = Matrix . H.toLists . S.extract $ lstmWf lw -- Weight Forget (W_f) + , refLstmUf = Matrix . H.toLists . S.extract $ lstmUf lw -- Cell State Forget (U_f) + , refLstmBf = Vector . H.toList . S.extract $ lstmBf lw -- Bias Forget (b_f) + , refLstmWi = Matrix . H.toLists . S.extract $ lstmWi lw -- Weight Input (W_i) + , refLstmUi = Matrix . H.toLists . S.extract $ lstmUi lw -- Cell State Input (U_i) + , refLstmBi = Vector . H.toList . S.extract $ lstmBi lw -- Bias Input (b_i) + , refLstmWo = Matrix . H.toLists . S.extract $ lstmWo lw -- Weight Output (W_o) + , refLstmUo = Matrix . H.toLists . S.extract $ lstmUo lw -- Cell State Output (U_o) + , refLstmBo = Vector . H.toList . 
S.extract $ lstmBo lw -- Bias Output (b_o) + , refLstmWc = Matrix . H.toLists . S.extract $ lstmWc lw -- Weight Cell (W_c) + , refLstmBc = Vector . H.toList . S.extract $ lstmBc lw -- Bias Cell (b_c) + } runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a) -runLSTM RefLSTM {..} cell input = +runLSTM rl cell input = let -- Forget state vector - f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell + f_t = sigmoid $ refLstmBf rl #+ refLstmWf rl #> input #+ refLstmUf rl #> cell -- Input state vector - i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell + i_t = sigmoid $ refLstmBi rl #+ refLstmWi rl #> input #+ refLstmUi rl #> cell -- Output state vector - o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell + o_t = sigmoid $ refLstmBo rl #+ refLstmWo rl #> input #+ refLstmUo rl #> cell -- Cell input state vector - c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input + c_x = fmap tanh $ refLstmBc rl #+ refLstmWc rl #> input -- Cell state c_t = f_t #* cell #+ i_t #* c_x -- Output (it's sometimes recommended to use tanh c_t)
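
For reference, the mechanical pattern applied throughout this patch is sketched below: a RecordWildCards match is replaced with an ordinary binding of the record plus explicit field accessors. The Opts type and the two run functions are illustrative only (they echo the ShakespeareOpts fields but do not appear in the patched modules).

{-# LANGUAGE RecordWildCards #-}   -- only needed for the "before" version

data Opts = Opts
  { rate       :: Double
  , iterations :: Int
  }

-- Before: {..} implicitly binds every field of Opts as a local name.
runBefore :: Opts -> [Double]
runBefore Opts {..} = replicate iterations rate

-- After: bind the whole record and apply the accessors explicitly,
-- so the extension (and the implicit bindings) are no longer needed.
runAfter :: Opts -> [Double]
runAfter opts = replicate (iterations opts) (rate opts)

Keeping both spellings in mind makes the substitution easy to check in review: every bare `field` coming from a wildcard match becomes `field record` (or `(field record)` where an argument position needs parentheses) on the rewritten side.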