Remove use of RecordWildCards

Erik de Castro Lopo 2020-04-11 17:18:23 +10:00
parent f19da0f74d
commit 5206c95c42
8 changed files with 91 additions and 104 deletions
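The change is mechanical throughout: every `SomeRecord {..}` pattern is replaced by a plain binding, and each field is then read through its accessor function. A minimal sketch of the before/after shape (the `stepSize` functions below are hypothetical and the `Double` field types are assumed for illustration; only the field names come from the diff):

{-# LANGUAGE RecordWildCards #-}

-- Assumed field types; only the field names appear in the diff below.
data LearningParameters = LearningParameters
  { learningRate        :: Double
  , learningMomentum    :: Double
  , learningRegulariser :: Double
  }

-- Before: RecordWildCards brings the fields into scope as bare names,
-- hiding where learningRate and learningMomentum come from.
stepSizeWild :: LearningParameters -> Double
stepSizeWild LearningParameters {..} = learningRate * (1 - learningMomentum)

-- After: bind the record to a name and apply the accessors explicitly,
-- which is the pattern used in every file touched by this commit.
stepSize :: LearningParameters -> Double
stepSize lp = learningRate lp * (1 - learningMomentum lp)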

View File

@@ -1,6 +1,5 @@
 {-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE CPP #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
@@ -83,34 +82,34 @@ loadShakespeare path = do
 return (V.fromList hot, m, cs)
 trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
-trainSlice !rate !net !recIns input offset size =
+trainSlice !lrate !net !recIns input offset size =
 let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
 in case reverse e of
 (o : l : xs) ->
 let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
-in trainRecurrent rate net recIns examples
+in trainRecurrent lrate net recIns examples
 _ -> error "Not enough input"
 where
 x = fromMaybe (error "Hot variable didn't fit.")
 runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
-runShakespeare ShakespeareOpts {..} = do
-(shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
+runShakespeare opts = do
+(shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare $ trainingFile opts
 (net0, i0) <- lift $
-case loadPath of
+case loadPath opts of
 Just loadFile -> netLoad loadFile
 Nothing -> (,0) <$> randomNet
 (trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do
-xs <- take (iterations `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
-let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
-results <- take 1000 <$> generateParagraph trained bestInput temperature oneHotMap oneHotDictionary ( S1D $ konst 0)
+xs <- take (iterations opts `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
+let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice (rate opts) n i shakespeare offset size) (net, io) xs
+results <- take 1000 <$> generateParagraph trained bestInput (temperature opts) oneHotMap oneHotDictionary ( S1D $ konst 0)
 putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
 putStrLn (unAnnotateCapitals results)
 return (trained, bestInput)
-) (net0, i0) $ replicate 10 sequenceSize
-case savePath of
+) (net0, i0) $ replicate 10 (sequenceSize opts)
+case savePath opts of
 Just saveFile -> lift . B.writeFile saveFile $ runPut (put trained >> put bestInput)
 Nothing -> return ()
@@ -122,12 +121,12 @@ generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes
 -> Vector a
 -> S ('D1 n)
 -> IO [a]
-generateParagraph n s temperature hotmap hotdict =
+generateParagraph n s temp hotmap hotdict =
 go s
 where
 go x y =
 do let (_, ns, o) = runRecurrent n x y
-un <- sample temperature hotdict o
+un <- sample temp hotdict o
 Just re <- return $ makeHot hotmap un
 rest <- unsafeInterleaveIO $ go ns re
 return (un : rest)

View File

@@ -1,7 +1,6 @@
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
@@ -139,8 +138,8 @@ instance ( KnownNat channels
 , KnownNat (kernelRows * kernelColumns * channels)
 ) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
 type Gradient (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Convolution' channels filters kernelRows kernelColumns strideRows strideColumns)
-runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
-let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
+runUpdate lp (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
+let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
 in Convolution newKernel newMomentum
 createRandom = randomConvolution

View File

@@ -1,7 +1,6 @@
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
@@ -138,8 +137,8 @@ instance ( KnownNat channels
 , KnownNat (kernelRows * kernelColumns * filters)
 ) => UpdateLayer (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) where
 type Gradient (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Deconvolution' channels filters kernelRows kernelColumns strideRows strideColumns)
-runUpdate LearningParameters {..} (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
-let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
+runUpdate lp (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
+let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
 in Deconvolution newKernel newMomentum
 createRandom = randomDeconvolution

View File

@@ -1,5 +1,4 @@
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
@@ -38,9 +37,9 @@ instance Show (FullyConnected i o) where
 instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
 type Gradient (FullyConnected i o) = (FullyConnected' i o)
-runUpdate LearningParameters {..} (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
-let (newBias, newBiasMomentum) = descendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum
-(newActivations, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum
+runUpdate lp (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
+let (newBias, newBiasMomentum) = descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldBias biasGradient oldBiasMomentum
+(newActivations, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldActivations activationGradient oldMomentum
 in FullyConnected (FullyConnected' newBias newActivations) (FullyConnected' newBiasMomentum newMomentum)
 createRandom = randomFullyConnected

View File

@@ -6,7 +6,6 @@
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RecordWildCards #-}
 module Grenade.Recurrent.Core.Runner (
 runRecurrentExamples
@@ -101,11 +100,11 @@ updateRecInputs :: Fractional (RecurrentInputs sublayers)
 -> RecurrentInputs sublayers
 -> RecurrentInputs sublayers
-updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
-= () :~~+> updateRecInputs l xs ys
-updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
-= (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
+updateRecInputs lp (() :~~+> xs) (() :~~+> ys)
+= () :~~+> updateRecInputs lp xs ys
+updateRecInputs lp (x :~@+> xs) (y :~@+> ys)
+= (realToFrac (1 - learningRate lp * learningRegulariser lp) * x - realToFrac (learningRate lp) * y) :~@+> updateRecInputs lp xs ys
 updateRecInputs _ RINil RINil
 = RINil

View File

@@ -1,7 +1,6 @@
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
@@ -60,11 +59,11 @@ instance Show (BasicRecurrent i o) where
 instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
 type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)
-runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
-let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
+runUpdate lp (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
+let newBiasMomentum = konst (learningMomentum lp) * oldBiasMomentum - konst (learningRate lp) * biasGradient
 newBias = oldBias + newBiasMomentum
-newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
-regulariser = konst (learningRegulariser * learningRate) * oldActivations
+newMomentum = konst (learningMomentum lp) * oldMomentum - konst (learningRate lp) * activationGradient
+regulariser = konst (learningRegulariser lp * learningRate lp) * oldActivations
 newActivations = oldActivations + newMomentum - regulariser
 in BasicRecurrent newBias newBiasMomentum newActivations newMomentum
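For reference, the update rule spelled out in this runUpdate is plain gradient descent with momentum and L2 regularisation, and the matching argument order of descendMatrix/descendVector in the other layers suggests they implement the same rule. Writing $\eta$ for learningRate, $\mu$ for learningMomentum, $\lambda$ for learningRegulariser, $g$ for the gradient, $m$ for the momentum and $w$ for the weights:

$$
\begin{aligned}
m' &= \mu\, m - \eta\, g \\
w' &= w + m' - \eta \lambda\, w
\end{aligned}
$$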

View File

@@ -3,7 +3,6 @@
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
@@ -75,7 +74,7 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
 -- Run the update function for each group matrix/vector of weights, momentums and gradients.
 -- Hmm, maybe the function should be used instead of passing in the learning parameters.
-runUpdate LearningParameters {..} (LSTM w m) g =
+runUpdate lp (LSTM w m) g =
 let (wf, wf') = u lstmWf w m g
 (uf, uf') = u lstmUf w m g
 (bf, bf') = v lstmBf w m g
@@ -92,11 +91,11 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
 -- Utility function for updating with the momentum, gradients, and weights.
 u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
 u e (e -> weights) (e -> momentum) (e -> gradient) =
-descendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
+descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
 v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
 v e (e -> weights) (e -> momentum) (e -> gradient) =
-descendVector learningRate learningMomentum learningRegulariser weights gradient momentum
+descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
 -- There's a lot of updates here, so to try and minimise the number of data copies
 -- we'll create a mutable bucket for each.
@@ -137,18 +136,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
 type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
 -- Forward propagation for the LSTM layer.
 -- The size of the cell state is also the size of the output.
-runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
+runRecurrentForwards (LSTM lw _) (S1D cell) (S1D input) =
 let -- Forget state vector
-f_s = lstmBf + lstmWf #> input + lstmUf #> cell
+f_s = lstmBf lw + lstmWf lw #> input + lstmUf lw #> cell
 f_t = sigmoid f_s
 -- Input state vector
-i_s = lstmBi + lstmWi #> input + lstmUi #> cell
+i_s = lstmBi lw + lstmWi lw #> input + lstmUi lw #> cell
 i_t = sigmoid i_s
 -- Output state vector
-o_s = lstmBo + lstmWo #> input + lstmUo #> cell
+o_s = lstmBo lw + lstmWo lw #> input + lstmUo lw #> cell
 o_t = sigmoid o_s
 -- Cell input state vector
-c_s = lstmBc + lstmWc #> input
+c_s = lstmBc lw + lstmWc lw #> input
 c_x = tanh c_s
 -- Cell state
 c_t = f_t * cell + i_t * c_x
@@ -162,7 +161,7 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
 --
 -- There's a test version using the AD library without hmatrix in the test
 -- suite. These should match always.
-runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
+runRecurrentBackwards (LSTM lw _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
 let -- Reverse Mode AD Derivitives
 c_t' = h_t' * o_t + cellGrad
@@ -179,8 +178,8 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
 c_s' = tanh' c_s * c_x'
 -- The derivatives to pass sideways (recurrent) and downwards
-cell' = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
-input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'
+cell' = tr (lstmUf lw) #> f_s' + tr (lstmUo lw) #> o_s' + tr (lstmUi lw) #> i_s' + c_t' * f_t
+input' = tr (lstmWf lw) #> f_s' + tr (lstmWo lw) #> o_s' + tr (lstmWi lw) #> i_s' + tr (lstmWc lw) #> c_s'
 -- Calculate the gradient Matricies for the input
 lstmWf' = f_s' `outer` input
@@ -239,18 +238,18 @@ tanh' :: (Floating a) => a -> a
 tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
 instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
-put (LSTM LSTMWeights {..} _) = do
-u lstmWf
-u lstmUf
-v lstmBf
-u lstmWi
-u lstmUi
-v lstmBi
-u lstmWo
-u lstmUo
-v lstmBo
-u lstmWc
-v lstmBc
+put (LSTM lw _) = do
+u (lstmWf lw)
+u (lstmUf lw)
+v (lstmBf lw)
+u (lstmWi lw)
+u (lstmUi lw)
+v (lstmBi lw)
+u (lstmWo lw)
+u (lstmUo lw)
+v (lstmBo lw)
+u (lstmWc lw)
+v (lstmBc lw)
 where
 u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
 u = putListOf put . LA.toList . LA.flatten . extract
@@ -258,18 +257,8 @@ instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
 v = putListOf put . LA.toList . extract
 get = do
-lstmWf <- u
-lstmUf <- u
-lstmBf <- v
-lstmWi <- u
-lstmUi <- u
-lstmBi <- v
-lstmWo <- u
-lstmUo <- u
-lstmBo <- v
-lstmWc <- u
-lstmBc <- v
-return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
+w <- LSTMWeights <$> u <*> u <*> v <*> u <*> u <*> v <*> u <*> u <*> v <*> u <*> v
+return $ LSTM w (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
 where
 u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
 u = let f = fromIntegral $ natVal (Proxy :: Proxy a)
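The rewritten decoder builds the weights positionally: `LSTMWeights <$> u <*> u <*> v <*> …` feeds each decoded matrix (`u`) or vector (`v`) to the constructor in its field order (Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Bc), mirroring the order they are written in `put` above, so the eleven named binds collapse into one applicative chain. A minimal sketch of the idea with a hypothetical two-field record (not a type from this library):

import Data.Serialize (Get, Serialize (..))

data Pair = Pair { firstField :: Int, secondField :: Int }

-- Equivalent to: do { a <- get; b <- get; return (Pair a b) }
-- Fields are consumed in the constructor's declaration order, so the
-- encoder and decoder must agree on that order, as put and get do above.
getPair :: Get Pair
getPair = Pair <$> get <*> get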

View File

@@ -6,7 +6,6 @@
 {-# LANGUAGE DeriveFunctor #-}
 {-# LANGUAGE DeriveFoldable #-}
 {-# LANGUAGE DeriveTraversable #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
@@ -17,7 +16,11 @@ import Data.Reflection
 import Numeric.AD.Mode.Reverse
 import Numeric.AD.Internal.Reverse (Tape)
+import GHC.TypeLits (KnownNat)
+import Grenade.Recurrent.Layers.LSTM (LSTMWeights (..))
 import qualified Grenade.Recurrent.Layers.LSTM as LSTM
 import qualified Numeric.LinearAlgebra.Static as S
 import qualified Numeric.LinearAlgebra as H
@@ -54,31 +57,32 @@ data RefLSTM a = RefLSTM
 , refLstmBc :: Vector a -- Bias Cell (b_c)
 } deriving (Functor, Foldable, Traversable, Eq, Show)
-lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double
-lstmToReference LSTM.LSTMWeights {..} =
-let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget (W_f)
-refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f)
-refLstmBf = Vector . H.toList . S.extract $ lstmBf -- Bias Forget (b_f)
-refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input (W_i)
-refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input (U_i)
-refLstmBi = Vector . H.toList . S.extract $ lstmBi -- Bias Input (b_i)
-refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output (W_o)
-refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o)
-refLstmBo = Vector . H.toList . S.extract $ lstmBo -- Bias Output (b_o)
-refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell (W_c)
-refLstmBc = Vector . H.toList . S.extract $ lstmBc -- Bias Cell (b_c)
-in RefLSTM {..}
+lstmToReference :: (KnownNat a, KnownNat b) => LSTM.LSTMWeights a b -> RefLSTM Double
+lstmToReference lw =
+RefLSTM
+{ refLstmWf = Matrix . H.toLists . S.extract $ lstmWf lw -- Weight Forget (W_f)
+, refLstmUf = Matrix . H.toLists . S.extract $ lstmUf lw -- Cell State Forget (U_f)
+, refLstmBf = Vector . H.toList . S.extract $ lstmBf lw -- Bias Forget (b_f)
+, refLstmWi = Matrix . H.toLists . S.extract $ lstmWi lw -- Weight Input (W_i)
+, refLstmUi = Matrix . H.toLists . S.extract $ lstmUi lw -- Cell State Input (U_i)
+, refLstmBi = Vector . H.toList . S.extract $ lstmBi lw -- Bias Input (b_i)
+, refLstmWo = Matrix . H.toLists . S.extract $ lstmWo lw -- Weight Output (W_o)
+, refLstmUo = Matrix . H.toLists . S.extract $ lstmUo lw -- Cell State Output (U_o)
+, refLstmBo = Vector . H.toList . S.extract $ lstmBo lw -- Bias Output (b_o)
+, refLstmWc = Matrix . H.toLists . S.extract $ lstmWc lw -- Weight Cell (W_c)
+, refLstmBc = Vector . H.toList . S.extract $ lstmBc lw -- Bias Cell (b_c)
+}
 runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a)
-runLSTM RefLSTM {..} cell input =
+runLSTM rl cell input =
 let -- Forget state vector
-f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell
+f_t = sigmoid $ refLstmBf rl #+ refLstmWf rl #> input #+ refLstmUf rl #> cell
 -- Input state vector
-i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell
+i_t = sigmoid $ refLstmBi rl #+ refLstmWi rl #> input #+ refLstmUi rl #> cell
 -- Output state vector
-o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell
+o_t = sigmoid $ refLstmBo rl #+ refLstmWo rl #> input #+ refLstmUo rl #> cell
 -- Cell input state vector
-c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input
+c_x = fmap tanh $ refLstmBc rl #+ refLstmWc rl #> input
 -- Cell state
 c_t = f_t #* cell #+ i_t #* c_x
 -- Output (it's sometimes recommended to use tanh c_t)