Mirror of https://github.com/HuwCampbell/grenade.git (synced 2024-11-21 21:59:30 +03:00)

Remove use of RecordWildCards

parent f19da0f74d
commit 5206c95c42
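
This commit drops the RecordWildCards extension and rewrites every `Record {..}` pattern to bind the record to a name and call its field accessors explicitly, e.g. `runUpdate LearningParameters {..} ...` becomes `runUpdate lp ...` with `learningRate lp`, `learningMomentum lp` and `learningRegulariser lp` in the body. A minimal sketch of the pattern, assuming a simplified LearningParameters record (the Double field types and the `step` functions here are illustrative only, not part of the commit):

{-# LANGUAGE RecordWildCards #-}

data LearningParameters = LearningParameters
  { learningRate        :: Double
  , learningMomentum    :: Double
  , learningRegulariser :: Double
  }

-- Before: RecordWildCards implicitly brings learningRate, learningMomentum
-- and learningRegulariser into scope from the matched record.
stepWild :: LearningParameters -> Double -> Double -> Double
stepWild LearningParameters {..} weight gradient =
  weight - learningRate * (gradient + learningRegulariser * weight)

-- After: bind the record to a name and apply its field accessors explicitly,
-- so every use of a learning parameter is visible at the call site.
step :: LearningParameters -> Double -> Double -> Double
step lp weight gradient =
  weight - learningRate lp * (gradient + learningRegulariser lp * weight)

The diff below applies this rewrite across the convolution, deconvolution and fully connected layers, the recurrent runner, the basic recurrent and LSTM layers, the LSTM reference test, and the Shakespeare example.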
@@ -1,6 +1,5 @@
 {-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE CPP #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
@@ -83,34 +82,34 @@ loadShakespeare path = do
   return (V.fromList hot, m, cs)
 
 trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
-trainSlice !rate !net !recIns input offset size =
+trainSlice !lrate !net !recIns input offset size =
   let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
   in case reverse e of
     (o : l : xs) ->
       let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
-      in trainRecurrent rate net recIns examples
+      in trainRecurrent lrate net recIns examples
     _ -> error "Not enough input"
   where
     x = fromMaybe (error "Hot variable didn't fit.")
 
 runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
-runShakespeare ShakespeareOpts {..} = do
-  (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
+runShakespeare opts = do
+  (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare $ trainingFile opts
   (net0, i0) <- lift $
-    case loadPath of
+    case loadPath opts of
       Just loadFile -> netLoad loadFile
       Nothing -> (,0) <$> randomNet
 
   (trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do
-    xs <- take (iterations `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
-    let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
-    results <- take 1000 <$> generateParagraph trained bestInput temperature oneHotMap oneHotDictionary ( S1D $ konst 0)
+    xs <- take (iterations opts `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
+    let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice (rate opts) n i shakespeare offset size) (net, io) xs
+    results <- take 1000 <$> generateParagraph trained bestInput (temperature opts) oneHotMap oneHotDictionary ( S1D $ konst 0)
     putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
     putStrLn (unAnnotateCapitals results)
     return (trained, bestInput)
-    ) (net0, i0) $ replicate 10 sequenceSize
+    ) (net0, i0) $ replicate 10 (sequenceSize opts)
 
-  case savePath of
+  case savePath opts of
     Just saveFile -> lift . B.writeFile saveFile $ runPut (put trained >> put bestInput)
     Nothing -> return ()
 
@@ -122,12 +121,12 @@ generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes
                   -> Vector a
                   -> S ('D1 n)
                   -> IO [a]
-generateParagraph n s temperature hotmap hotdict =
+generateParagraph n s temp hotmap hotdict =
   go s
   where
     go x y =
       do let (_, ns, o) = runRecurrent n x y
-         un <- sample temperature hotdict o
+         un <- sample temp hotdict o
          Just re <- return $ makeHot hotmap un
          rest <- unsafeInterleaveIO $ go ns re
          return (un : rest)
@@ -1,7 +1,6 @@
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
@@ -139,8 +138,8 @@ instance ( KnownNat channels
          , KnownNat (kernelRows * kernelColumns * channels)
          ) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
   type Gradient (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Convolution' channels filters kernelRows kernelColumns strideRows strideColumns)
-  runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
-    let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
+  runUpdate lp (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
+    let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
     in Convolution newKernel newMomentum
 
   createRandom = randomConvolution
@@ -1,7 +1,6 @@
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
@@ -138,8 +137,8 @@ instance ( KnownNat channels
          , KnownNat (kernelRows * kernelColumns * filters)
          ) => UpdateLayer (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) where
   type Gradient (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Deconvolution' channels filters kernelRows kernelColumns strideRows strideColumns)
-  runUpdate LearningParameters {..} (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
-    let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
+  runUpdate lp (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
+    let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
     in Deconvolution newKernel newMomentum
 
   createRandom = randomDeconvolution
@@ -1,5 +1,4 @@
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
@@ -38,9 +37,9 @@ instance Show (FullyConnected i o) where
 instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
   type Gradient (FullyConnected i o) = (FullyConnected' i o)
 
-  runUpdate LearningParameters {..} (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
-    let (newBias, newBiasMomentum) = descendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum
-        (newActivations, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum
+  runUpdate lp (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
+    let (newBias, newBiasMomentum) = descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldBias biasGradient oldBiasMomentum
+        (newActivations, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldActivations activationGradient oldMomentum
     in FullyConnected (FullyConnected' newBias newActivations) (FullyConnected' newBiasMomentum newMomentum)
 
   createRandom = randomFullyConnected
@@ -6,7 +6,6 @@
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RecordWildCards #-}
 
 module Grenade.Recurrent.Core.Runner (
     runRecurrentExamples
@@ -101,11 +100,11 @@ updateRecInputs :: Fractional (RecurrentInputs sublayers)
                 -> RecurrentInputs sublayers
                 -> RecurrentInputs sublayers
 
-updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
-  = () :~~+> updateRecInputs l xs ys
+updateRecInputs lp (() :~~+> xs) (() :~~+> ys)
+  = () :~~+> updateRecInputs lp xs ys
 
-updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
-  = (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
+updateRecInputs lp (x :~@+> xs) (y :~@+> ys)
+  = (realToFrac (1 - learningRate lp * learningRegulariser lp) * x - realToFrac (learningRate lp) * y) :~@+> updateRecInputs lp xs ys
 
 updateRecInputs _ RINil RINil
   = RINil
@@ -1,7 +1,6 @@
 {-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
@@ -60,11 +59,11 @@ instance Show (BasicRecurrent i o) where
 instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
   type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)
 
-  runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
-    let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
+  runUpdate lp (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
+    let newBiasMomentum = konst (learningMomentum lp) * oldBiasMomentum - konst (learningRate lp) * biasGradient
         newBias = oldBias + newBiasMomentum
-        newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
-        regulariser = konst (learningRegulariser * learningRate) * oldActivations
+        newMomentum = konst (learningMomentum lp) * oldMomentum - konst (learningRate lp) * activationGradient
+        regulariser = konst (learningRegulariser lp * learningRate lp) * oldActivations
         newActivations = oldActivations + newMomentum - regulariser
     in BasicRecurrent newBias newBiasMomentum newActivations newMomentum
 
@@ -3,7 +3,6 @@
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
@@ -75,7 +74,7 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
 
   -- Run the update function for each group matrix/vector of weights, momentums and gradients.
   -- Hmm, maybe the function should be used instead of passing in the learning parameters.
-  runUpdate LearningParameters {..} (LSTM w m) g =
+  runUpdate lp (LSTM w m) g =
     let (wf, wf') = u lstmWf w m g
         (uf, uf') = u lstmUf w m g
         (bf, bf') = v lstmBf w m g
@@ -92,11 +91,11 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
     -- Utility function for updating with the momentum, gradients, and weights.
     u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
     u e (e -> weights) (e -> momentum) (e -> gradient) =
-      descendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
+      descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
 
     v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
     v e (e -> weights) (e -> momentum) (e -> gradient) =
-      descendVector learningRate learningMomentum learningRegulariser weights gradient momentum
+      descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
 
     -- There's a lot of updates here, so to try and minimise the number of data copies
     -- we'll create a mutable bucket for each.
@@ -137,18 +136,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
   type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
   -- Forward propagation for the LSTM layer.
   -- The size of the cell state is also the size of the output.
-  runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
+  runRecurrentForwards (LSTM lw _) (S1D cell) (S1D input) =
     let -- Forget state vector
-        f_s = lstmBf + lstmWf #> input + lstmUf #> cell
+        f_s = lstmBf lw + lstmWf lw #> input + lstmUf lw #> cell
         f_t = sigmoid f_s
         -- Input state vector
-        i_s = lstmBi + lstmWi #> input + lstmUi #> cell
+        i_s = lstmBi lw + lstmWi lw #> input + lstmUi lw #> cell
         i_t = sigmoid i_s
         -- Output state vector
-        o_s = lstmBo + lstmWo #> input + lstmUo #> cell
+        o_s = lstmBo lw + lstmWo lw #> input + lstmUo lw #> cell
         o_t = sigmoid o_s
         -- Cell input state vector
-        c_s = lstmBc + lstmWc #> input
+        c_s = lstmBc lw + lstmWc lw #> input
         c_x = tanh c_s
         -- Cell state
         c_t = f_t * cell + i_t * c_x
@@ -162,7 +161,7 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
   --
   -- There's a test version using the AD library without hmatrix in the test
   -- suite. These should match always.
-  runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
+  runRecurrentBackwards (LSTM lw _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
     let -- Reverse Mode AD Derivitives
         c_t' = h_t' * o_t + cellGrad
 
@@ -179,8 +178,8 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
         c_s' = tanh' c_s * c_x'
 
         -- The derivatives to pass sideways (recurrent) and downwards
-        cell' = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
-        input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'
+        cell' = tr (lstmUf lw) #> f_s' + tr (lstmUo lw) #> o_s' + tr (lstmUi lw) #> i_s' + c_t' * f_t
+        input' = tr (lstmWf lw) #> f_s' + tr (lstmWo lw) #> o_s' + tr (lstmWi lw) #> i_s' + tr (lstmWc lw) #> c_s'
 
         -- Calculate the gradient Matricies for the input
         lstmWf' = f_s' `outer` input
@@ -239,44 +238,34 @@ tanh' :: (Floating a) => a -> a
 tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
 
 instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
-  put (LSTM LSTMWeights {..} _) = do
-    u lstmWf
-    u lstmUf
-    v lstmBf
-    u lstmWi
-    u lstmUi
-    v lstmBi
-    u lstmWo
-    u lstmUo
-    v lstmBo
-    u lstmWc
-    v lstmBc
+  put (LSTM lw _) = do
+    u (lstmWf lw)
+    u (lstmUf lw)
+    v (lstmBf lw)
+    u (lstmWi lw)
+    u (lstmUi lw)
+    v (lstmBi lw)
+    u (lstmWo lw)
+    u (lstmUo lw)
+    v (lstmBo lw)
+    u (lstmWc lw)
+    v (lstmBc lw)
     where
       u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
      u = putListOf put . LA.toList . LA.flatten . extract
      v :: forall a. (KnownNat a) => Putter (R a)
      v = putListOf put . LA.toList . extract
 
   get = do
-    lstmWf <- u
-    lstmUf <- u
-    lstmBf <- v
-    lstmWi <- u
-    lstmUi <- u
-    lstmBi <- v
-    lstmWo <- u
-    lstmUo <- u
-    lstmBo <- v
-    lstmWc <- u
-    lstmBc <- v
-    return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
+    w <- LSTMWeights <$> u <*> u <*> v <*> u <*> u <*> v <*> u <*> u <*> v <*> u <*> v
+    return $ LSTM w (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
     where
      u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
      u = let f = fromIntegral $ natVal (Proxy :: Proxy a)
          in maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get
      v :: forall a. (KnownNat a) => Get (R a)
      v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get
 
      w0 = konst 0
      u0 = konst 0
      v0 = konst 0
@@ -6,7 +6,6 @@
 {-# LANGUAGE DeriveFunctor #-}
 {-# LANGUAGE DeriveFoldable #-}
 {-# LANGUAGE DeriveTraversable #-}
-{-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
 
@@ -15,9 +14,13 @@ module Test.Grenade.Recurrent.Layers.LSTM.Reference where
 
 import Data.Reflection
 import Numeric.AD.Mode.Reverse
-import Numeric.AD.Internal.Reverse ( Tape )
+import Numeric.AD.Internal.Reverse (Tape)
 
+import GHC.TypeLits (KnownNat)
+
+import Grenade.Recurrent.Layers.LSTM (LSTMWeights (..))
 import qualified Grenade.Recurrent.Layers.LSTM as LSTM
 
 import qualified Numeric.LinearAlgebra.Static as S
 import qualified Numeric.LinearAlgebra as H
+
@@ -54,31 +57,32 @@ data RefLSTM a = RefLSTM
   , refLstmBc :: Vector a -- Bias Cell (b_c)
   } deriving (Functor, Foldable, Traversable, Eq, Show)
 
-lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double
-lstmToReference LSTM.LSTMWeights {..} =
-  let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget (W_f)
-      refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f)
-      refLstmBf = Vector . H.toList . S.extract $ lstmBf -- Bias Forget (b_f)
-      refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input (W_i)
-      refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input (U_i)
-      refLstmBi = Vector . H.toList . S.extract $ lstmBi -- Bias Input (b_i)
-      refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output (W_o)
-      refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o)
-      refLstmBo = Vector . H.toList . S.extract $ lstmBo -- Bias Output (b_o)
-      refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell (W_c)
-      refLstmBc = Vector . H.toList . S.extract $ lstmBc -- Bias Cell (b_c)
-  in RefLSTM {..}
+lstmToReference :: (KnownNat a, KnownNat b) => LSTM.LSTMWeights a b -> RefLSTM Double
+lstmToReference lw =
+  RefLSTM
+    { refLstmWf = Matrix . H.toLists . S.extract $ lstmWf lw -- Weight Forget (W_f)
+    , refLstmUf = Matrix . H.toLists . S.extract $ lstmUf lw -- Cell State Forget (U_f)
+    , refLstmBf = Vector . H.toList . S.extract $ lstmBf lw -- Bias Forget (b_f)
+    , refLstmWi = Matrix . H.toLists . S.extract $ lstmWi lw -- Weight Input (W_i)
+    , refLstmUi = Matrix . H.toLists . S.extract $ lstmUi lw -- Cell State Input (U_i)
+    , refLstmBi = Vector . H.toList . S.extract $ lstmBi lw -- Bias Input (b_i)
+    , refLstmWo = Matrix . H.toLists . S.extract $ lstmWo lw -- Weight Output (W_o)
+    , refLstmUo = Matrix . H.toLists . S.extract $ lstmUo lw -- Cell State Output (U_o)
+    , refLstmBo = Vector . H.toList . S.extract $ lstmBo lw -- Bias Output (b_o)
+    , refLstmWc = Matrix . H.toLists . S.extract $ lstmWc lw -- Weight Cell (W_c)
+    , refLstmBc = Vector . H.toList . S.extract $ lstmBc lw -- Bias Cell (b_c)
+    }
 
 runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a)
-runLSTM RefLSTM {..} cell input =
+runLSTM rl cell input =
   let -- Forget state vector
-      f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell
+      f_t = sigmoid $ refLstmBf rl #+ refLstmWf rl #> input #+ refLstmUf rl #> cell
       -- Input state vector
-      i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell
+      i_t = sigmoid $ refLstmBi rl #+ refLstmWi rl #> input #+ refLstmUi rl #> cell
       -- Output state vector
-      o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell
+      o_t = sigmoid $ refLstmBo rl #+ refLstmWo rl #> input #+ refLstmUo rl #> cell
       -- Cell input state vector
-      c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input
+      c_x = fmap tanh $ refLstmBc rl #+ refLstmWc rl #> input
       -- Cell state
      c_t = f_t #* cell #+ i_t #* c_x
      -- Output (it's sometimes recommended to use tanh c_t)