mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-21 21:59:30 +03:00
Remove use of RecordWildCards
This commit is contained in:
parent
f19da0f74d
commit
5206c95c42
@ -1,6 +1,5 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
@ -83,34 +82,34 @@ loadShakespeare path = do
|
||||
return (V.fromList hot, m, cs)
|
||||
|
||||
trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
|
||||
trainSlice !rate !net !recIns input offset size =
|
||||
trainSlice !lrate !net !recIns input offset size =
|
||||
let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
|
||||
in case reverse e of
|
||||
(o : l : xs) ->
|
||||
let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
|
||||
in trainRecurrent rate net recIns examples
|
||||
in trainRecurrent lrate net recIns examples
|
||||
_ -> error "Not enough input"
|
||||
where
|
||||
x = fromMaybe (error "Hot variable didn't fit.")
|
||||
|
||||
runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
|
||||
runShakespeare ShakespeareOpts {..} = do
|
||||
(shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
|
||||
runShakespeare opts = do
|
||||
(shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare $ trainingFile opts
|
||||
(net0, i0) <- lift $
|
||||
case loadPath of
|
||||
case loadPath opts of
|
||||
Just loadFile -> netLoad loadFile
|
||||
Nothing -> (,0) <$> randomNet
|
||||
|
||||
(trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do
|
||||
xs <- take (iterations `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
|
||||
let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
|
||||
results <- take 1000 <$> generateParagraph trained bestInput temperature oneHotMap oneHotDictionary ( S1D $ konst 0)
|
||||
xs <- take (iterations opts `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
|
||||
let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice (rate opts) n i shakespeare offset size) (net, io) xs
|
||||
results <- take 1000 <$> generateParagraph trained bestInput (temperature opts) oneHotMap oneHotDictionary ( S1D $ konst 0)
|
||||
putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
|
||||
putStrLn (unAnnotateCapitals results)
|
||||
return (trained, bestInput)
|
||||
) (net0, i0) $ replicate 10 sequenceSize
|
||||
) (net0, i0) $ replicate 10 (sequenceSize opts)
|
||||
|
||||
case savePath of
|
||||
case savePath opts of
|
||||
Just saveFile -> lift . B.writeFile saveFile $ runPut (put trained >> put bestInput)
|
||||
Nothing -> return ()
|
||||
|
||||
@ -122,12 +121,12 @@ generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes
|
||||
-> Vector a
|
||||
-> S ('D1 n)
|
||||
-> IO [a]
|
||||
generateParagraph n s temperature hotmap hotdict =
|
||||
generateParagraph n s temp hotmap hotdict =
|
||||
go s
|
||||
where
|
||||
go x y =
|
||||
do let (_, ns, o) = runRecurrent n x y
|
||||
un <- sample temperature hotdict o
|
||||
un <- sample temp hotdict o
|
||||
Just re <- return $ makeHot hotmap un
|
||||
rest <- unsafeInterleaveIO $ go ns re
|
||||
return (un : rest)
|
||||
|
@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
@ -139,8 +138,8 @@ instance ( KnownNat channels
|
||||
, KnownNat (kernelRows * kernelColumns * channels)
|
||||
) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
|
||||
type Gradient (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Convolution' channels filters kernelRows kernelColumns strideRows strideColumns)
|
||||
runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
|
||||
let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
|
||||
runUpdate lp (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
|
||||
let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
|
||||
in Convolution newKernel newMomentum
|
||||
|
||||
createRandom = randomConvolution
|
||||
|
@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
@ -138,8 +137,8 @@ instance ( KnownNat channels
|
||||
, KnownNat (kernelRows * kernelColumns * filters)
|
||||
) => UpdateLayer (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) where
|
||||
type Gradient (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Deconvolution' channels filters kernelRows kernelColumns strideRows strideColumns)
|
||||
runUpdate LearningParameters {..} (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
|
||||
let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
|
||||
runUpdate lp (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
|
||||
let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
|
||||
in Deconvolution newKernel newMomentum
|
||||
|
||||
createRandom = randomDeconvolution
|
||||
|
@ -1,5 +1,4 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
@ -38,9 +37,9 @@ instance Show (FullyConnected i o) where
|
||||
instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
|
||||
type Gradient (FullyConnected i o) = (FullyConnected' i o)
|
||||
|
||||
runUpdate LearningParameters {..} (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
|
||||
let (newBias, newBiasMomentum) = descendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum
|
||||
(newActivations, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum
|
||||
runUpdate lp (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
|
||||
let (newBias, newBiasMomentum) = descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldBias biasGradient oldBiasMomentum
|
||||
(newActivations, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldActivations activationGradient oldMomentum
|
||||
in FullyConnected (FullyConnected' newBias newActivations) (FullyConnected' newBiasMomentum newMomentum)
|
||||
|
||||
createRandom = randomFullyConnected
|
||||
|
@ -6,7 +6,6 @@
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
|
||||
module Grenade.Recurrent.Core.Runner (
|
||||
runRecurrentExamples
|
||||
@ -101,11 +100,11 @@ updateRecInputs :: Fractional (RecurrentInputs sublayers)
|
||||
-> RecurrentInputs sublayers
|
||||
-> RecurrentInputs sublayers
|
||||
|
||||
updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
|
||||
= () :~~+> updateRecInputs l xs ys
|
||||
updateRecInputs lp (() :~~+> xs) (() :~~+> ys)
|
||||
= () :~~+> updateRecInputs lp xs ys
|
||||
|
||||
updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
|
||||
= (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
|
||||
updateRecInputs lp (x :~@+> xs) (y :~@+> ys)
|
||||
= (realToFrac (1 - learningRate lp * learningRegulariser lp) * x - realToFrac (learningRate lp) * y) :~@+> updateRecInputs lp xs ys
|
||||
|
||||
updateRecInputs _ RINil RINil
|
||||
= RINil
|
||||
|
@ -1,7 +1,6 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
@ -60,11 +59,11 @@ instance Show (BasicRecurrent i o) where
|
||||
instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
|
||||
type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)
|
||||
|
||||
runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
|
||||
let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
|
||||
runUpdate lp (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
|
||||
let newBiasMomentum = konst (learningMomentum lp) * oldBiasMomentum - konst (learningRate lp) * biasGradient
|
||||
newBias = oldBias + newBiasMomentum
|
||||
newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
|
||||
regulariser = konst (learningRegulariser * learningRate) * oldActivations
|
||||
newMomentum = konst (learningMomentum lp) * oldMomentum - konst (learningRate lp) * activationGradient
|
||||
regulariser = konst (learningRegulariser lp * learningRate lp) * oldActivations
|
||||
newActivations = oldActivations + newMomentum - regulariser
|
||||
in BasicRecurrent newBias newBiasMomentum newActivations newMomentum
|
||||
|
||||
|
@ -3,7 +3,6 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
@ -75,7 +74,7 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
|
||||
|
||||
-- Run the update function for each group matrix/vector of weights, momentums and gradients.
|
||||
-- Hmm, maybe the function should be used instead of passing in the learning parameters.
|
||||
runUpdate LearningParameters {..} (LSTM w m) g =
|
||||
runUpdate lp (LSTM w m) g =
|
||||
let (wf, wf') = u lstmWf w m g
|
||||
(uf, uf') = u lstmUf w m g
|
||||
(bf, bf') = v lstmBf w m g
|
||||
@ -92,11 +91,11 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
|
||||
-- Utility function for updating with the momentum, gradients, and weights.
|
||||
u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
|
||||
u e (e -> weights) (e -> momentum) (e -> gradient) =
|
||||
descendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
|
||||
descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
|
||||
|
||||
v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
|
||||
v e (e -> weights) (e -> momentum) (e -> gradient) =
|
||||
descendVector learningRate learningMomentum learningRegulariser weights gradient momentum
|
||||
descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
|
||||
|
||||
-- There's a lot of updates here, so to try and minimise the number of data copies
|
||||
-- we'll create a mutable bucket for each.
|
||||
@ -137,18 +136,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
|
||||
type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
|
||||
-- Forward propagation for the LSTM layer.
|
||||
-- The size of the cell state is also the size of the output.
|
||||
runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
|
||||
runRecurrentForwards (LSTM lw _) (S1D cell) (S1D input) =
|
||||
let -- Forget state vector
|
||||
f_s = lstmBf + lstmWf #> input + lstmUf #> cell
|
||||
f_s = lstmBf lw + lstmWf lw #> input + lstmUf lw #> cell
|
||||
f_t = sigmoid f_s
|
||||
-- Input state vector
|
||||
i_s = lstmBi + lstmWi #> input + lstmUi #> cell
|
||||
i_s = lstmBi lw + lstmWi lw #> input + lstmUi lw #> cell
|
||||
i_t = sigmoid i_s
|
||||
-- Output state vector
|
||||
o_s = lstmBo + lstmWo #> input + lstmUo #> cell
|
||||
o_s = lstmBo lw + lstmWo lw #> input + lstmUo lw #> cell
|
||||
o_t = sigmoid o_s
|
||||
-- Cell input state vector
|
||||
c_s = lstmBc + lstmWc #> input
|
||||
c_s = lstmBc lw + lstmWc lw #> input
|
||||
c_x = tanh c_s
|
||||
-- Cell state
|
||||
c_t = f_t * cell + i_t * c_x
|
||||
@ -162,7 +161,7 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
|
||||
--
|
||||
-- There's a test version using the AD library without hmatrix in the test
|
||||
-- suite. These should match always.
|
||||
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
|
||||
runRecurrentBackwards (LSTM lw _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
|
||||
let -- Reverse Mode AD Derivitives
|
||||
c_t' = h_t' * o_t + cellGrad
|
||||
|
||||
@ -179,8 +178,8 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
|
||||
c_s' = tanh' c_s * c_x'
|
||||
|
||||
-- The derivatives to pass sideways (recurrent) and downwards
|
||||
cell' = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
|
||||
input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'
|
||||
cell' = tr (lstmUf lw) #> f_s' + tr (lstmUo lw) #> o_s' + tr (lstmUi lw) #> i_s' + c_t' * f_t
|
||||
input' = tr (lstmWf lw) #> f_s' + tr (lstmWo lw) #> o_s' + tr (lstmWi lw) #> i_s' + tr (lstmWc lw) #> c_s'
|
||||
|
||||
-- Calculate the gradient Matricies for the input
|
||||
lstmWf' = f_s' `outer` input
|
||||
@ -239,44 +238,34 @@ tanh' :: (Floating a) => a -> a
|
||||
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
|
||||
|
||||
instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
|
||||
put (LSTM LSTMWeights {..} _) = do
|
||||
u lstmWf
|
||||
u lstmUf
|
||||
v lstmBf
|
||||
u lstmWi
|
||||
u lstmUi
|
||||
v lstmBi
|
||||
u lstmWo
|
||||
u lstmUo
|
||||
v lstmBo
|
||||
u lstmWc
|
||||
v lstmBc
|
||||
where
|
||||
u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
|
||||
u = putListOf put . LA.toList . LA.flatten . extract
|
||||
v :: forall a. (KnownNat a) => Putter (R a)
|
||||
v = putListOf put . LA.toList . extract
|
||||
put (LSTM lw _) = do
|
||||
u (lstmWf lw)
|
||||
u (lstmUf lw)
|
||||
v (lstmBf lw)
|
||||
u (lstmWi lw)
|
||||
u (lstmUi lw)
|
||||
v (lstmBi lw)
|
||||
u (lstmWo lw)
|
||||
u (lstmUo lw)
|
||||
v (lstmBo lw)
|
||||
u (lstmWc lw)
|
||||
v (lstmBc lw)
|
||||
where
|
||||
u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
|
||||
u = putListOf put . LA.toList . LA.flatten . extract
|
||||
v :: forall a. (KnownNat a) => Putter (R a)
|
||||
v = putListOf put . LA.toList . extract
|
||||
|
||||
get = do
|
||||
lstmWf <- u
|
||||
lstmUf <- u
|
||||
lstmBf <- v
|
||||
lstmWi <- u
|
||||
lstmUi <- u
|
||||
lstmBi <- v
|
||||
lstmWo <- u
|
||||
lstmUo <- u
|
||||
lstmBo <- v
|
||||
lstmWc <- u
|
||||
lstmBc <- v
|
||||
return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
|
||||
where
|
||||
u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
|
||||
u = let f = fromIntegral $ natVal (Proxy :: Proxy a)
|
||||
in maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get
|
||||
v :: forall a. (KnownNat a) => Get (R a)
|
||||
v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get
|
||||
w <- LSTMWeights <$> u <*> u <*> v <*> u <*> u <*> v <*> u <*> u <*> v <*> u <*> v
|
||||
return $ LSTM w (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
|
||||
where
|
||||
u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
|
||||
u = let f = fromIntegral $ natVal (Proxy :: Proxy a)
|
||||
in maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get
|
||||
v :: forall a. (KnownNat a) => Get (R a)
|
||||
v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get
|
||||
|
||||
w0 = konst 0
|
||||
u0 = konst 0
|
||||
v0 = konst 0
|
||||
w0 = konst 0
|
||||
u0 = konst 0
|
||||
v0 = konst 0
|
||||
|
@ -6,7 +6,6 @@
|
||||
{-# LANGUAGE DeriveFunctor #-}
|
||||
{-# LANGUAGE DeriveFoldable #-}
|
||||
{-# LANGUAGE DeriveTraversable #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
||||
@ -15,9 +14,13 @@ module Test.Grenade.Recurrent.Layers.LSTM.Reference where
|
||||
|
||||
import Data.Reflection
|
||||
import Numeric.AD.Mode.Reverse
|
||||
import Numeric.AD.Internal.Reverse ( Tape )
|
||||
import Numeric.AD.Internal.Reverse (Tape)
|
||||
|
||||
import GHC.TypeLits (KnownNat)
|
||||
|
||||
import Grenade.Recurrent.Layers.LSTM (LSTMWeights (..))
|
||||
import qualified Grenade.Recurrent.Layers.LSTM as LSTM
|
||||
|
||||
import qualified Numeric.LinearAlgebra.Static as S
|
||||
import qualified Numeric.LinearAlgebra as H
|
||||
|
||||
@ -54,31 +57,32 @@ data RefLSTM a = RefLSTM
|
||||
, refLstmBc :: Vector a -- Bias Cell (b_c)
|
||||
} deriving (Functor, Foldable, Traversable, Eq, Show)
|
||||
|
||||
lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double
|
||||
lstmToReference LSTM.LSTMWeights {..} =
|
||||
let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget (W_f)
|
||||
refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f)
|
||||
refLstmBf = Vector . H.toList . S.extract $ lstmBf -- Bias Forget (b_f)
|
||||
refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input (W_i)
|
||||
refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input (U_i)
|
||||
refLstmBi = Vector . H.toList . S.extract $ lstmBi -- Bias Input (b_i)
|
||||
refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output (W_o)
|
||||
refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o)
|
||||
refLstmBo = Vector . H.toList . S.extract $ lstmBo -- Bias Output (b_o)
|
||||
refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell (W_c)
|
||||
refLstmBc = Vector . H.toList . S.extract $ lstmBc -- Bias Cell (b_c)
|
||||
in RefLSTM {..}
|
||||
lstmToReference :: (KnownNat a, KnownNat b) => LSTM.LSTMWeights a b -> RefLSTM Double
|
||||
lstmToReference lw =
|
||||
RefLSTM
|
||||
{ refLstmWf = Matrix . H.toLists . S.extract $ lstmWf lw -- Weight Forget (W_f)
|
||||
, refLstmUf = Matrix . H.toLists . S.extract $ lstmUf lw -- Cell State Forget (U_f)
|
||||
, refLstmBf = Vector . H.toList . S.extract $ lstmBf lw -- Bias Forget (b_f)
|
||||
, refLstmWi = Matrix . H.toLists . S.extract $ lstmWi lw -- Weight Input (W_i)
|
||||
, refLstmUi = Matrix . H.toLists . S.extract $ lstmUi lw -- Cell State Input (U_i)
|
||||
, refLstmBi = Vector . H.toList . S.extract $ lstmBi lw -- Bias Input (b_i)
|
||||
, refLstmWo = Matrix . H.toLists . S.extract $ lstmWo lw -- Weight Output (W_o)
|
||||
, refLstmUo = Matrix . H.toLists . S.extract $ lstmUo lw -- Cell State Output (U_o)
|
||||
, refLstmBo = Vector . H.toList . S.extract $ lstmBo lw -- Bias Output (b_o)
|
||||
, refLstmWc = Matrix . H.toLists . S.extract $ lstmWc lw -- Weight Cell (W_c)
|
||||
, refLstmBc = Vector . H.toList . S.extract $ lstmBc lw -- Bias Cell (b_c)
|
||||
}
|
||||
|
||||
runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a)
|
||||
runLSTM RefLSTM {..} cell input =
|
||||
runLSTM rl cell input =
|
||||
let -- Forget state vector
|
||||
f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell
|
||||
f_t = sigmoid $ refLstmBf rl #+ refLstmWf rl #> input #+ refLstmUf rl #> cell
|
||||
-- Input state vector
|
||||
i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell
|
||||
i_t = sigmoid $ refLstmBi rl #+ refLstmWi rl #> input #+ refLstmUi rl #> cell
|
||||
-- Output state vector
|
||||
o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell
|
||||
o_t = sigmoid $ refLstmBo rl #+ refLstmWo rl #> input #+ refLstmUo rl #> cell
|
||||
-- Cell input state vector
|
||||
c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input
|
||||
c_x = fmap tanh $ refLstmBc rl #+ refLstmWc rl #> input
|
||||
-- Cell state
|
||||
c_t = f_t #* cell #+ i_t #* c_x
|
||||
-- Output (it's sometimes recommended to use tanh c_t)
|
||||
|
Loading…
Reference in New Issue
Block a user