mirror of https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00

Optimise Wengert tape for LSTM

This commit is contained in:
parent 4332d62c71
commit a5881518bf
@@ -5,61 +5,22 @@ import Criterion.Main
 import Grenade
 import Grenade.Recurrent
-import Grenade.Layers.Internal.Update
-
-import qualified Numeric.LinearAlgebra.Static as H

 main :: IO ()
 main = do
-  layer60  :: LSTM 40 60  <- createRandom
-  layer512 :: LSTM 40 512 <- createRandom
   input40 :: S ('D1 40) <- randomOfShape
-  rec60  :: S ('D1 60)  <- randomOfShape
-  rec512 :: S ('D1 512) <- randomOfShape
   lstm :: RecNet <- randomRecurrent

-  let upIn60  :: H.R 3600   = H.randomVector 1 H.Uniform * 2 - 1
-  let upIn512 :: H.R 262144 = H.randomVector 1 H.Uniform * 2 - 1

   defaultMain [
-      bgroup "lstm"   [ bench "forwards-60"    $ nf (nfT3 . uncurry (testRun60 layer60)) (rec60, input40)
-                      , bench "forwards-512"   $ nf (nfT3 . uncurry (testRun512 layer512)) (rec512, input40)
-                      , bench "backwards-60"   $ nf (nfT3 . uncurry4 (testRun60' layer60)) (rec60, input40, rec60, rec60)
-                      , bench "backwards-512"  $ nf (nfT3 . uncurry4 (testRun512' layer512)) (rec512, input40, rec512, rec512)
-                      ]
-    , bgroup "update" [ bench "matrix-60x60"   $ nf (uncurry3 (descendVector 1 1 1)) (upIn60, upIn60, upIn60)
-                      , bench "matrix-512x512" $ nf (uncurry3 (descendVector 1 1 1)) (upIn512, upIn512, upIn512)
-                      ]
-    , bgroup "train"  [ bench "one-time-step"  $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
+      bgroup "train" [ bench "one-time-step"    $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
                      , bench "ten-time-steps"   $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 10 (input40, Just input40)
                      , bench "fifty-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 50 (input40, Just input40)
                      ]
     ]

-testRun60 :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> ((S ('D1 60), S ('D1 40)), S ('D1 60), S ('D1 60))
-testRun60 = runRecurrentForwards
-
-testRun60' :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> S ('D1 60) -> S ('D1 60) -> (Gradient (LSTM 40 60), S ('D1 60), S ('D1 40))
-testRun60' = curry . runRecurrentBackwards
-
-testRun512 :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> ((S ('D1 512), S ('D1 40)), S ('D1 512), S ('D1 512))
-testRun512 = runRecurrentForwards
-
-testRun512' :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> S ('D1 512) -> S ('D1 512) -> (Gradient (LSTM 40 512), S ('D1 512), S ('D1 40))
-testRun512' = curry . runRecurrentBackwards
-
-uncurry4 :: (t -> t1 -> t2 -> t3 -> t4) -> (t, t1, t2, t3) -> t4
-uncurry4 f (a,b,c,d) = f a b c d
-
-uncurry3 :: (t -> t1 -> t2 -> t3) -> (t, t1, t2) -> t3
-uncurry3 f (a,b,c) = f a b c

 nfT2 :: (a, b) -> (a, b)
 nfT2 (!a, !b) = (a, b)

-nfT3 :: (a, b, c) -> (b, c)
-nfT3 (!_, !b, !c) = (b, c)

 type R = Recurrent
 type RecNet = RecurrentNetwork '[ R (LSTM 40 512), R (LSTM 512 40) ]
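The remaining train benchmarks pair whnf with nfT2 because trainRecurrent returns a pair, and whnf alone would stop at the outermost (,) constructor without running the training step. A minimal standalone sketch (not part of the commit) of why the bang patterns in nfT2 are enough to force both components:

    {-# LANGUAGE BangPatterns #-}
    module Main where

    -- Evaluating the result of nfT2 to weak head normal form hits the
    -- bang patterns, which in turn force both components of the pair.
    nfT2 :: (a, b) -> (a, b)
    nfT2 (!a, !b) = (a, b)

    main :: IO ()
    main = do
      let p = nfT2 (sum [1 .. 1000000 :: Int], length "abc")
      -- seq only evaluates p to WHNF, yet both computations inside run.
      p `seq` putStrLn "both components forced"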
@@ -36,12 +36,12 @@ instance Serialize Elu where
   get = return Elu

 instance (KnownNat i) => Layer Elu ('D1 i) ('D1 i) where
-  type Tape Elu ('D1 i) ('D1 i) = S ('D1 i)
+  type Tape Elu ('D1 i) ('D1 i) = LAS.R i

-  runForwards _ (S1D y) = (S1D y, S1D (elu y))
+  runForwards _ (S1D y) = (y, S1D (elu y))
     where
       elu = LAS.dvmap (\a -> if a <= 0 then exp a - 1 else a)
-  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (elu' y * dEdy))
+  runBackwards _ y (S1D dEdy) = ((), S1D (elu' y * dEdy))
     where
       elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1)
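The tape here now carries the bare static vector (LAS.R i) rather than the shape-wrapped S ('D1 i), saving a constructor round-trip on the backward pass. For reference, a standalone sketch of the same maths as the dvmap bodies above, on plain Double:

    -- ELU: identity for positive inputs, exp a - 1 at or below zero.
    elu :: Double -> Double
    elu a = if a <= 0 then exp a - 1 else a

    -- Its derivative: 1 for positive inputs, exp a at or below zero,
    -- which is why the tape only needs the original input.
    elu' :: Double -> Double
    elu' a = if a <= 0 then exp a else 1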
@@ -46,12 +46,12 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
   createRandom = randomFullyConnected

 instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
-  type Tape (FullyConnected i o) ('D1 i) ('D1 o) = S ('D1 i)
+  type Tape (FullyConnected i o) ('D1 i) ('D1 o) = R i
   -- Do a matrix vector multiplication and return the result.
-  runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (S1D v, S1D (wB + wN #> v))
+  runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (v, S1D (wB + wN #> v))

   -- Run a backpropagation step for a fully connected layer.
-  runBackwards (FullyConnected (FullyConnected' _ wN) _) (S1D x) (S1D dEdy) =
+  runBackwards (FullyConnected (FullyConnected' _ wN) _) x (S1D dEdy) =
     let wB' = dEdy
         mm' = dEdy `outer` x
         -- calculate derivatives for next step
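The gradients in this hunk follow directly from y = wB + wN · x: the bias gradient is dEdy itself, the weight gradient is the outer product dEdy ⊗ x, and the derivative passed back is wNᵀ · dEdy. A scalar sketch (hypothetical names, not the grenade API) of the same bookkeeping:

    -- Forward: y = b + w * x; the tape is just the input x.
    fcForwards :: Double -> Double -> Double -> (Double, Double)
    fcForwards b w x = (x, b + w * x)

    -- Backward: gradients for the bias, the weight, and the previous layer.
    fcBackwards :: Double -> Double -> Double -> (Double, Double, Double)
    fcBackwards w x dEdy =
      ( dEdy       -- wB': bias gradient
      , dEdy * x   -- mm': weight gradient (an outer product in the vector case)
      , w * dEdy   -- derivative handed back to the previous layer
      )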
@@ -34,18 +34,24 @@ instance UpdateLayer Logit where
   createRandom = return Logit

 instance (a ~ b, SingI a) => Layer Logit a b where
+  -- Wengert tape optimisation:
+  --
+  -- Derivative of the sigmoid function is
+  --   d σ(x) / dx = σ(x) • (1 - σ(x))
+  -- but we have already calculated σ(x) in
+  -- the forward pass, so just store that
+  -- and use it in the backwards pass.
   type Tape Logit a b = S a
-  runForwards _ a = (a, logistic a)
-  runBackwards _ a g = ((), logistic' a * g)
+  runForwards _ a =
+    let l = sigmoid a
+    in  (l, l)
+  runBackwards _ l g =
+    let sigmoid' = l * (1 - l)
+    in  ((), sigmoid' * g)

 instance Serialize Logit where
   put _ = return ()
   get = return Logit

-logistic :: Floating a => a -> a
-logistic x = 1 / (1 + exp (-x))
-
-logistic' :: Floating a => a -> a
-logistic' x = logix * (1 - logix)
-  where
-    logix = logistic x
+sigmoid :: Floating a => a -> a
+sigmoid x = 1 / (1 + exp (-x))
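The optimisation above in one standalone sketch (plain Double, hypothetical names): the forward pass returns σ(x) as both output and tape, so the backward pass is a single multiply-subtract instead of another exp.

    sigmoid :: Double -> Double
    sigmoid x = 1 / (1 + exp (negate x))

    -- Forward pass: the tape and the output are the same value, σ(x).
    forwardLogit :: Double -> (Double, Double)
    forwardLogit x = let l = sigmoid x in (l, l)

    -- Backward pass: dE/dx = σ(x) * (1 - σ(x)) * dE/dy, reusing the tape.
    backwardLogit :: Double -> Double -> Double
    backwardLogit l g = l * (1 - l) * g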
@@ -127,37 +127,12 @@ instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where

 instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where

-  type RecTape (LSTM i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i))
+  -- The tape stores essentially every variable we calculate,
+  -- so we don't have to run any forwards component again.
+  type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
   -- Forward propagation for the LSTM layer.
   -- The size of the cell state is also the size of the output.
   runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
-    let -- Forget state vector
-        f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
-        -- Input state vector
-        i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
-        -- Output state vector
-        o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
-        -- Cell input state vector
-        c_x = tanh $ lstmBc + lstmWc #> input
-        -- Cell state
-        c_t = f_t * cell + i_t * c_x
-        -- Output (it's sometimes recommended to use tanh c_t)
-        h_t = o_t * c_t
-    in ((S1D cell, S1D input), S1D c_t, S1D h_t)
-
-  -- Run a backpropagation step for an LSTM layer.
-  -- We're doing all the derivatives by hand here, so one should
-  -- be extra careful when changing this.
-  --
-  -- There's a test version using the AD library without hmatrix in the test
-  -- suite. These should always match.
-  runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell, S1D input) (S1D cellGrad) (S1D h_t') =
-    -- We're not keeping the Wengert tape during the forward pass,
-    -- so we're duplicating some work here.
-    --
-    -- If I was being generous, I'd call it checkpointing.
-    --
-    -- Maybe think about better ways to store some intermediate states.
     let -- Forget state vector
         f_s = lstmBf + lstmWf #> input + lstmUf #> cell
         f_t = sigmoid f_s
@@ -172,8 +147,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
         c_x = tanh c_s
         -- Cell state
         c_t = f_t * cell + i_t * c_x
+        -- Output (it's sometimes recommended to use tanh c_t)
+        h_t = o_t * c_t
+    in ((cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t), S1D c_t, S1D h_t)

-        -- Reverse Mode AD Derivatives
+  -- Run a backpropagation step for an LSTM layer.
+  -- We're doing all the derivatives by hand here, so one should
+  -- be extra careful when changing this.
+  --
+  -- There's a test version using the AD library without hmatrix in the test
+  -- suite. These should always match.
+  runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
+    let -- Reverse Mode AD Derivatives
         c_t' = h_t' * o_t + cellGrad

         f_t' = c_t' * cell
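To see why carrying the tape pays off, consider a single forget gate in isolation. A toy scalar sketch (hypothetical names, not the grenade types): the forward pass records the gate activation f_t on the tape, and the backward pass reuses it for the sigmoid derivative instead of rerunning the gate.

    -- Tape for one forget gate: previous cell, input, and activation f_t.
    type GateTape = (Double, Double, Double)

    sigmoid :: Double -> Double
    sigmoid x = 1 / (1 + exp (negate x))

    -- Forward: gate the cell state and record everything we computed.
    gateForwards :: Double -> Double -> Double -> (GateTape, Double)
    gateForwards w cell input =
      let f_t = sigmoid (w * input)   -- forget gate activation
          c_t = f_t * cell            -- gated cell state
      in ((cell, input, f_t), c_t)

    -- Backward: everything needed is already on the tape.
    gateBackwards :: GateTape -> Double -> (Double, Double)
    gateBackwards (cell, input, f_t) c_t' =
      let f_t' = c_t' * cell              -- dE/d f_t
          s'   = f_t * (1 - f_t) * f_t'   -- through the sigmoid, reusing f_t
      in (s' * input, c_t' * f_t)         -- (weight gradient, cell gradient)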
@@ -235,7 +220,8 @@ randomLSTM = do

 -- | Maths
 --
--- TODO: move to not here
+-- TODO: Move to not here
+-- Optimise backwards derivative
 sigmoid :: Floating a => a -> a
 sigmoid x = 1 / (1 + exp (-x))
@@ -65,7 +65,9 @@ prop_lstm_reference_backwards =
   input :: S.R 3 <- forAll randomVector
   cell :: S.R 2 <- forAll randomVector
   net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
-  let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
+  let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
+                  = runRecurrentForwards net (S1D cell) (S1D input)
+      actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
   case actualBacks of
     (actualGradients, _, _ :: S ('D1 3)) ->
       let refNet = Reference.lstmToReference lstmWeights
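The properties can no longer fabricate a tape by hand, since its shape is now an internal detail of the layer; each one runs the forward pass first and feeds the resulting tape to the backward pass. A minimal standalone sketch of that calling convention (hypothetical types, not the grenade API):

    -- A layer whose backward pass consumes an opaque tape built forwards.
    data Step tape = Step
      { forwards  :: Double -> (tape, Double)  -- input -> (tape, output)
      , backwards :: tape -> Double -> Double  -- tape -> dE/dy -> dE/dx
      }

    -- Thread the tape: run forwards, then seed backwards with dE/dy = 1.
    gradAt :: Step tape -> Double -> Double
    gradAt step x =
      let (tape, _y) = forwards step x
      in backwards step tape 1

For example, gradAt (Step (\x -> (x, x * x)) (\t g -> 2 * t * g)) 3 evaluates to 6, the derivative of x² at 3.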
@@ -79,7 +81,9 @@ prop_lstm_reference_backwards_input =
   input :: S.R 3 <- forAll randomVector
   cell :: S.R 2 <- forAll randomVector
   net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
-  let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
+  let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
+                  = runRecurrentForwards net (S1D cell) (S1D input)
+      actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
   case actualBacks of
     (_, _, S1D actualGradients :: S ('D1 3)) ->
       let refNet = Reference.lstmToReference lstmWeights
@@ -93,7 +97,9 @@ prop_lstm_reference_backwards_cell =
   input :: S.R 3 <- forAll randomVector
   cell :: S.R 2 <- forAll randomVector
   net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
-  let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
+  let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
+                  = runRecurrentForwards net (S1D cell) (S1D input)
+      actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
   case actualBacks of
     (_, S1D actualGradients, _ :: S ('D1 3)) ->
       let refNet = Reference.lstmToReference lstmWeights