From a5881518bf69d6033e0ccfae21f7b3ea509305a3 Mon Sep 17 00:00:00 2001
From: Huw Campbell
Date: Sat, 16 Dec 2017 21:21:58 +1100
Subject: [PATCH] Optimise Wengert tape for LSTM

---
 bench/bench-lstm.hs                        | 41 +------------------
 src/Grenade/Layers/Elu.hs                  |  6 +--
 src/Grenade/Layers/FullyConnected.hs       |  6 +--
 src/Grenade/Layers/Logit.hs                | 24 ++++++-----
 src/Grenade/Recurrent/Layers/LSTM.hs       | 46 ++++++++--------------
 test/Test/Grenade/Recurrent/Layers/LSTM.hs | 12 ++++--
 6 files changed, 47 insertions(+), 88 deletions(-)

diff --git a/bench/bench-lstm.hs b/bench/bench-lstm.hs
index 4c2afc6..8018516 100644
--- a/bench/bench-lstm.hs
+++ b/bench/bench-lstm.hs
@@ -5,61 +5,22 @@ import Criterion.Main
 
 import Grenade
 import Grenade.Recurrent
-import Grenade.Layers.Internal.Update
-
-import qualified Numeric.LinearAlgebra.Static as H
 
 main :: IO ()
 main = do
-  layer60  :: LSTM 40 60  <- createRandom
-  layer512 :: LSTM 40 512 <- createRandom
   input40  :: S ('D1 40)  <- randomOfShape
-  rec60    :: S ('D1 60)  <- randomOfShape
-  rec512   :: S ('D1 512) <- randomOfShape
   lstm     :: RecNet      <- randomRecurrent
 
-  let upIn60  :: H.R 3600   = H.randomVector 1 H.Uniform * 2 - 1
-  let upIn512 :: H.R 262144 = H.randomVector 1 H.Uniform * 2 - 1
-
   defaultMain [
-      bgroup "lstm"   [ bench "forwards-60"    $ nf (nfT3 . uncurry (testRun60 layer60)) (rec60, input40)
-                      , bench "forwards-512"   $ nf (nfT3 . uncurry (testRun512 layer512)) (rec512, input40)
-                      , bench "backwards-60"   $ nf (nfT3 . uncurry4 (testRun60' layer60)) (rec60, input40, rec60, rec60)
-                      , bench "backwards-512"  $ nf (nfT3 . uncurry4 (testRun512' layer512)) (rec512, input40, rec512, rec512)
-                      ]
-    , bgroup "update" [ bench "matrix-60x60"   $ nf (uncurry3 (descendVector 1 1 1)) (upIn60, upIn60, upIn60)
-                      , bench "matrix-512x512" $ nf (uncurry3 (descendVector 1 1 1)) (upIn512, upIn512, upIn512)
-                      ]
-    , bgroup "train"  [ bench "one-time-step"    $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
+      bgroup "train"  [ bench "one-time-step"    $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
                       , bench "ten-time-steps"   $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 10 (input40, Just input40)
                       , bench "fifty-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 50 (input40, Just input40)
                       ]
     ]
 
-testRun60 :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> ((S ('D1 60), S ('D1 40)), S ('D1 60), S ('D1 60))
-testRun60 = runRecurrentForwards
-
-testRun60' :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> S ('D1 60) -> S ('D1 60) -> (Gradient (LSTM 40 60), S ('D1 60), S ('D1 40))
-testRun60' = curry . runRecurrentBackwards
-
-testRun512 :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> ((S ('D1 512), S ('D1 40)), S ('D1 512), S ('D1 512))
-testRun512 = runRecurrentForwards
-
-testRun512' :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> S ('D1 512) -> S ('D1 512) -> (Gradient (LSTM 40 512), S ('D1 512), S ('D1 40))
-testRun512' = curry . runRecurrentBackwards
-
-uncurry4 :: (t -> t1 -> t2 -> t3 -> t4) -> (t, t1, t2, t3) -> t4
-uncurry4 f (a,b,c,d) = f a b c d
-
-uncurry3 :: (t -> t1 -> t2 -> t3) -> (t, t1, t2) -> t3
-uncurry3 f (a,b,c) = f a b c
-
 nfT2 :: (a, b) -> (a, b)
 nfT2 (!a, !b) = (a, b)
 
-nfT3 :: (a, b, c) -> (b, c)
-nfT3 (!_, !b, !c) = (b, c)
-
 type R = Recurrent
 type RecNet = RecurrentNetwork '[ R (LSTM 40 512), R (LSTM 512 40) ]
diff --git a/src/Grenade/Layers/Elu.hs b/src/Grenade/Layers/Elu.hs
index 1d2f06f..a109086 100644
--- a/src/Grenade/Layers/Elu.hs
+++ b/src/Grenade/Layers/Elu.hs
@@ -36,12 +36,12 @@ instance Serialize Elu where
   get = return Elu
 
 instance ( KnownNat i) => Layer Elu ('D1 i) ('D1 i) where
-  type Tape Elu ('D1 i) ('D1 i) = S ('D1 i)
+  type Tape Elu ('D1 i) ('D1 i) = LAS.R i
 
-  runForwards _ (S1D y) = (S1D y, S1D (elu y))
+  runForwards _ (S1D y) = (y, S1D (elu y))
     where
       elu = LAS.dvmap (\a -> if a <= 0 then exp a - 1 else a)
 
-  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (elu' y * dEdy))
+  runBackwards _ y (S1D dEdy) = ((), S1D (elu' y * dEdy))
     where
       elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1)
diff --git a/src/Grenade/Layers/FullyConnected.hs b/src/Grenade/Layers/FullyConnected.hs
index f96dcd7..ecaad46 100644
--- a/src/Grenade/Layers/FullyConnected.hs
+++ b/src/Grenade/Layers/FullyConnected.hs
@@ -46,12 +46,12 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
   createRandom = randomFullyConnected
 
 instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
-  type Tape (FullyConnected i o) ('D1 i) ('D1 o) = S ('D1 i)
+  type Tape (FullyConnected i o) ('D1 i) ('D1 o) = R i
 
   -- Do a matrix-vector multiplication and return the result.
-  runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (S1D v, S1D (wB + wN #> v))
+  runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (v, S1D (wB + wN #> v))
 
   -- Run a backpropagation step for a fully connected layer.
-  runBackwards (FullyConnected (FullyConnected' _ wN) _) (S1D x) (S1D dEdy) =
+  runBackwards (FullyConnected (FullyConnected' _ wN) _) x (S1D dEdy) =
     let wB'  = dEdy
         mm'  = dEdy `outer` x
         -- calculate derivatives for next step
diff --git a/src/Grenade/Layers/Logit.hs b/src/Grenade/Layers/Logit.hs
index 9097ee5..cc58482 100644
--- a/src/Grenade/Layers/Logit.hs
+++ b/src/Grenade/Layers/Logit.hs
@@ -34,18 +34,24 @@ instance UpdateLayer Logit where
   createRandom = return Logit
 
 instance (a ~ b, SingI a) => Layer Logit a b where
+  -- Wengert tape optimisation:
+  --
+  -- The derivative of the sigmoid function is
+  --   d σ(x) / dx = σ(x) • (1 - σ(x)),
+  -- but we have already calculated σ(x) in
+  -- the forward pass, so we just store that
+  -- value and reuse it in the backwards pass.
   type Tape Logit a b = S a
 
-  runForwards _ a = (a, logistic a)
-  runBackwards _ a g = ((), logistic' a * g)
+  runForwards _ a =
+    let l = sigmoid a
+    in  (l, l)
+  runBackwards _ l g =
+    let sigmoid' = l * (1 - l)
+    in  ((), sigmoid' * g)
 
 instance Serialize Logit where
   put _ = return ()
   get = return Logit
 
-logistic :: Floating a => a -> a
-logistic x = 1 / (1 + exp (-x))
-
-logistic' :: Floating a => a -> a
-logistic' x = logix * (1 - logix)
-  where
-    logix = logistic x
+sigmoid :: Floating a => a -> a
+sigmoid x = 1 / (1 + exp (-x))
diff --git a/src/Grenade/Recurrent/Layers/LSTM.hs b/src/Grenade/Recurrent/Layers/LSTM.hs
index cb49ad3..a3d6993 100644
--- a/src/Grenade/Recurrent/Layers/LSTM.hs
+++ b/src/Grenade/Recurrent/Layers/LSTM.hs
@@ -127,37 +127,12 @@ instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
 
 instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
-  type RecTape (LSTM i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i))
+  -- The tape stores essentially every variable we calculate,
+  -- so no part of the forward pass has to be run again.
+  type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
 
   -- Forward propagation for the LSTM layer.
   -- The size of the cell state is also the size of the output.
   runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
-    let -- Forget state vector
-        f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
-        -- Input state vector
-        i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
-        -- Output state vector
-        o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
-        -- Cell input state vector
-        c_x = tanh $ lstmBc + lstmWc #> input
-        -- Cell state
-        c_t = f_t * cell + i_t * c_x
-        -- Output (it's sometimes recommended to use tanh c_t)
-        h_t = o_t * c_t
-    in ((S1D cell, S1D input), S1D c_t, S1D h_t)
-
-  -- Run a backpropogation step for an LSTM layer.
-  -- We're doing all the derivatives by hand here, so one should
-  -- be extra careful when changing this.
-  --
-  -- There's a test version using the AD library without hmatrix in the test
-  -- suite. These should match always.
-  runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell, S1D input) (S1D cellGrad) (S1D h_t') =
-    -- We're not keeping the Wengert tape during the forward pass,
-    -- so we're duplicating some work here.
-    --
-    -- If I was being generous, I'd call it checkpointing.
-    --
-    -- Maybe think about better ways to store some intermediate states.
     let -- Forget state vector
         f_s = lstmBf + lstmWf #> input + lstmUf #> cell
         f_t = sigmoid f_s
@@ -172,8 +147,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
         c_x = tanh c_s
         -- Cell state
         c_t = f_t * cell + i_t * c_x
+        -- Output (it's sometimes recommended to use tanh c_t)
+        h_t = o_t * c_t
+    in ((cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t), S1D c_t, S1D h_t)
 
-    -- Reverse Mode AD Derivitives
+  -- Run a backpropagation step for an LSTM layer.
+  -- We're doing all the derivatives by hand here, so one should
+  -- be extra careful when changing this.
+  --
+  -- There's a test version using the AD library without hmatrix in the test
+  -- suite. These should always match.
+  runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
+    let -- Reverse Mode AD Derivatives
         c_t' = h_t' * o_t + cellGrad
 
         f_t' = c_t' * cell
@@ -235,7 +220,8 @@ randomLSTM = do
 
 -- | Maths
 --
--- TODO: move to not here
+-- TODO: Move these somewhere more appropriate.
+-- Optimise backwards derivative
 sigmoid :: Floating a => a -> a
 sigmoid x = 1 / (1 + exp (-x))
diff --git a/test/Test/Grenade/Recurrent/Layers/LSTM.hs b/test/Test/Grenade/Recurrent/Layers/LSTM.hs
index 4f20fb3..4722f4f 100644
--- a/test/Test/Grenade/Recurrent/Layers/LSTM.hs
+++ b/test/Test/Grenade/Recurrent/Layers/LSTM.hs
@@ -65,7 +65,9 @@ prop_lstm_reference_backwards =
     input :: S.R 3 <- forAll randomVector
     cell  :: S.R 2 <- forAll randomVector
    net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
-    let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
+    let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
+          = runRecurrentForwards net (S1D cell) (S1D input)
+        actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
     case actualBacks of
       (actualGradients, _, _ :: S ('D1 3)) ->
         let refNet = Reference.lstmToReference lstmWeights
@@ -79,7 +81,9 @@ prop_lstm_reference_backwards_input =
     input :: S.R 3 <- forAll randomVector
     cell  :: S.R 2 <- forAll randomVector
     net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
-    let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
+    let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
+          = runRecurrentForwards net (S1D cell) (S1D input)
+        actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
     case actualBacks of
       (_, _, S1D actualGradients :: S ('D1 3)) ->
         let refNet = Reference.lstmToReference lstmWeights
@@ -93,7 +97,9 @@ prop_lstm_reference_backwards_cell =
     input :: S.R 3 <- forAll randomVector
     cell  :: S.R 2 <- forAll randomVector
     net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
-    let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
+    let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
+          = runRecurrentForwards net (S1D cell) (S1D input)
+        actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
     case actualBacks of
       (_, S1D actualGradients, _ :: S ('D1 3)) ->
        let refNet = Reference.lstmToReference lstmWeights
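
Aside: the Wengert-tape idea this patch applies is easiest to see on the sigmoid
layer in isolation. Below is a minimal, self-contained sketch using plain Doubles
rather than Grenade's typed vectors; the names forwardsTaped and backwardsTaped
are illustrative assumptions, not Grenade's API. The forward pass stores σ(x) as
the tape, and the backward pass reuses it, since dσ(x)/dx = σ(x) • (1 - σ(x)) and
σ(x) has already been computed.

    -- For the sigmoid, the cheapest tape is the forward output itself.
    sigmoid :: Double -> Double
    sigmoid x = 1 / (1 + exp (-x))

    -- Forward pass: return the output and keep a copy of it as the tape.
    forwardsTaped :: Double -> (Double, Double)
    forwardsTaped x = let s = sigmoid x in (s, s)

    -- Backward pass: reuse the stored σ(x); no exp is recomputed.
    backwardsTaped :: Double -> Double -> Double
    backwardsTaped tape gradOut = tape * (1 - tape) * gradOut

    main :: IO ()
    main = do
      let (y, tape) = forwardsTaped 0.5
      print y                          -- σ(0.5)
      print (backwardsTaped tape 1.0)  -- σ(0.5) * (1 - σ(0.5))

The LSTM change is the same trade at scale: the eleven-element RecTape keeps the
gate pre-activations and activations from the forward pass, so backpropagation no
longer recomputes the matrix products and nonlinearities, at the cost of holding
those vectors in memory for the life of the tape.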