Optimise Wengert tape for LSTM

This commit is contained in:
Huw Campbell 2017-12-16 21:21:58 +11:00
parent 4332d62c71
commit a5881518bf
6 changed files with 47 additions and 88 deletions

View File

@ -5,61 +5,22 @@ import Criterion.Main
import Grenade import Grenade
import Grenade.Recurrent import Grenade.Recurrent
import Grenade.Layers.Internal.Update
import qualified Numeric.LinearAlgebra.Static as H
main :: IO () main :: IO ()
main = do main = do
layer60 :: LSTM 40 60 <- createRandom
layer512 :: LSTM 40 512 <- createRandom
input40 :: S ('D1 40) <- randomOfShape input40 :: S ('D1 40) <- randomOfShape
rec60 :: S ('D1 60) <- randomOfShape
rec512 :: S ('D1 512) <- randomOfShape
lstm :: RecNet <- randomRecurrent lstm :: RecNet <- randomRecurrent
let upIn60 :: H.R 3600 = H.randomVector 1 H.Uniform * 2 - 1
let upIn512 :: H.R 262144 = H.randomVector 1 H.Uniform * 2 - 1
defaultMain [ defaultMain [
bgroup "lstm" [ bench "forwards-60" $ nf (nfT3 . uncurry (testRun60 layer60)) (rec60, input40) bgroup "train" [ bench "one-time-step" $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
, bench "forwards-512" $ nf (nfT3 . uncurry (testRun512 layer512)) (rec512, input40)
, bench "backwards-60" $ nf (nfT3 . uncurry4 (testRun60' layer60)) (rec60, input40, rec60, rec60)
, bench "backwards-512" $ nf (nfT3 . uncurry4 (testRun512' layer512)) (rec512, input40, rec512, rec512)
]
, bgroup "update" [ bench "matrix-60x60" $ nf (uncurry3 (descendVector 1 1 1)) (upIn60, upIn60, upIn60)
, bench "matrix-512x512" $ nf (uncurry3 (descendVector 1 1 1)) (upIn512, upIn512, upIn512)
]
, bgroup "train" [ bench "one-time-step" $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
, bench "ten-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 10 (input40, Just input40) , bench "ten-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 10 (input40, Just input40)
, bench "fifty-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 50 (input40, Just input40) , bench "fifty-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 50 (input40, Just input40)
] ]
] ]
testRun60 :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> ((S ('D1 60), S ('D1 40)), S ('D1 60), S ('D1 60))
testRun60 = runRecurrentForwards
testRun60' :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> S ('D1 60) -> S ('D1 60) -> (Gradient (LSTM 40 60), S ('D1 60), S ('D1 40))
testRun60' = curry . runRecurrentBackwards
testRun512 :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> ((S ('D1 512), S ('D1 40)), S ('D1 512), S ('D1 512))
testRun512 = runRecurrentForwards
testRun512' :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> S ('D1 512) -> S ('D1 512) -> (Gradient (LSTM 40 512), S ('D1 512), S ('D1 40))
testRun512' = curry . runRecurrentBackwards
uncurry4 :: (t -> t1 -> t2 -> t3 -> t4) -> (t, t1, t2, t3) -> t4
uncurry4 f (a,b,c,d) = f a b c d
uncurry3 :: (t -> t1 -> t2 -> t3) -> (t, t1, t2) -> t3
uncurry3 f (a,b,c) = f a b c
nfT2 :: (a, b) -> (a, b) nfT2 :: (a, b) -> (a, b)
nfT2 (!a, !b) = (a, b) nfT2 (!a, !b) = (a, b)
nfT3 :: (a, b, c) -> (b, c)
nfT3 (!_, !b, !c) = (b, c)
type R = Recurrent type R = Recurrent
type RecNet = RecurrentNetwork '[ R (LSTM 40 512), R (LSTM 512 40) ] type RecNet = RecurrentNetwork '[ R (LSTM 40 512), R (LSTM 512 40) ]

View File

@ -36,12 +36,12 @@ instance Serialize Elu where
get = return Elu get = return Elu
instance ( KnownNat i) => Layer Elu ('D1 i) ('D1 i) where instance ( KnownNat i) => Layer Elu ('D1 i) ('D1 i) where
type Tape Elu ('D1 i) ('D1 i) = S ('D1 i) type Tape Elu ('D1 i) ('D1 i) = LAS.R i
runForwards _ (S1D y) = (S1D y, S1D (elu y)) runForwards _ (S1D y) = (y, S1D (elu y))
where where
elu = LAS.dvmap (\a -> if a <= 0 then exp a - 1 else a) elu = LAS.dvmap (\a -> if a <= 0 then exp a - 1 else a)
runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (elu' y * dEdy)) runBackwards _ y (S1D dEdy) = ((), S1D (elu' y * dEdy))
where where
elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1) elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1)

View File

@ -46,12 +46,12 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
createRandom = randomFullyConnected createRandom = randomFullyConnected
instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
type Tape (FullyConnected i o) ('D1 i) ('D1 o) = S ('D1 i) type Tape (FullyConnected i o) ('D1 i) ('D1 o) = R i
-- Do a matrix vector multiplication and return the result. -- Do a matrix vector multiplication and return the result.
runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (S1D v, S1D (wB + wN #> v)) runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (v, S1D (wB + wN #> v))
-- Run a backpropogation step for a full connected layer. -- Run a backpropogation step for a full connected layer.
runBackwards (FullyConnected (FullyConnected' _ wN) _) (S1D x) (S1D dEdy) = runBackwards (FullyConnected (FullyConnected' _ wN) _) x (S1D dEdy) =
let wB' = dEdy let wB' = dEdy
mm' = dEdy `outer` x mm' = dEdy `outer` x
-- calcluate derivatives for next step -- calcluate derivatives for next step

View File

@ -34,18 +34,24 @@ instance UpdateLayer Logit where
createRandom = return Logit createRandom = return Logit
instance (a ~ b, SingI a) => Layer Logit a b where instance (a ~ b, SingI a) => Layer Logit a b where
-- Wengert tape optimisation:
--
-- Derivative of the sigmoid function is
-- d σ(x) / dx = σ(x) • (1 - σ(x))
-- but we have already calculated σ(x) in
-- the forward pass, so just store that
-- and use it in the backwards pass.
type Tape Logit a b = S a type Tape Logit a b = S a
runForwards _ a = (a, logistic a) runForwards _ a =
runBackwards _ a g = ((), logistic' a * g) let l = sigmoid a
in (l, l)
runBackwards _ l g =
let sigmoid' = l * (1 - l)
in ((), sigmoid' * g)
instance Serialize Logit where instance Serialize Logit where
put _ = return () put _ = return ()
get = return Logit get = return Logit
logistic :: Floating a => a -> a sigmoid :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x)) sigmoid x = 1 / (1 + exp (-x))
logistic' :: Floating a => a -> a
logistic' x = logix * (1 - logix)
where
logix = logistic x

View File

@ -127,37 +127,12 @@ instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
type RecTape (LSTM i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i)) -- The tape stores essentially every variable we calculate,
-- so we don't have to run any forwards component again.
type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
-- Forward propagation for the LSTM layer. -- Forward propagation for the LSTM layer.
-- The size of the cell state is also the size of the output. -- The size of the cell state is also the size of the output.
runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) = runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
let -- Forget state vector
f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
-- Input state vector
i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
-- Output state vector
o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
-- Cell input state vector
c_x = tanh $ lstmBc + lstmWc #> input
-- Cell state
c_t = f_t * cell + i_t * c_x
-- Output (it's sometimes recommended to use tanh c_t)
h_t = o_t * c_t
in ((S1D cell, S1D input), S1D c_t, S1D h_t)
-- Run a backpropogation step for an LSTM layer.
-- We're doing all the derivatives by hand here, so one should
-- be extra careful when changing this.
--
-- There's a test version using the AD library without hmatrix in the test
-- suite. These should match always.
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell, S1D input) (S1D cellGrad) (S1D h_t') =
-- We're not keeping the Wengert tape during the forward pass,
-- so we're duplicating some work here.
--
-- If I was being generous, I'd call it checkpointing.
--
-- Maybe think about better ways to store some intermediate states.
let -- Forget state vector let -- Forget state vector
f_s = lstmBf + lstmWf #> input + lstmUf #> cell f_s = lstmBf + lstmWf #> input + lstmUf #> cell
f_t = sigmoid f_s f_t = sigmoid f_s
@ -172,8 +147,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
c_x = tanh c_s c_x = tanh c_s
-- Cell state -- Cell state
c_t = f_t * cell + i_t * c_x c_t = f_t * cell + i_t * c_x
-- Output (it's sometimes recommended to use tanh c_t)
h_t = o_t * c_t
in ((cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t), S1D c_t, S1D h_t)
-- Reverse Mode AD Derivitives -- Run a backpropogation step for an LSTM layer.
-- We're doing all the derivatives by hand here, so one should
-- be extra careful when changing this.
--
-- There's a test version using the AD library without hmatrix in the test
-- suite. These should match always.
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
let -- Reverse Mode AD Derivitives
c_t' = h_t' * o_t + cellGrad c_t' = h_t' * o_t + cellGrad
f_t' = c_t' * cell f_t' = c_t' * cell
@ -235,7 +220,8 @@ randomLSTM = do
-- | Maths -- | Maths
-- --
-- TODO: move to not here -- TODO: Move to not here
-- Optimise backwards derivative
sigmoid :: Floating a => a -> a sigmoid :: Floating a => a -> a
sigmoid x = 1 / (1 + exp (-x)) sigmoid x = 1 / (1 + exp (-x))

View File

@ -65,7 +65,9 @@ prop_lstm_reference_backwards =
input :: S.R 3 <- forAll randomVector input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2)) let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
= runRecurrentForwards net (S1D cell) (S1D input)
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of case actualBacks of
(actualGradients, _, _ :: S ('D1 3)) -> (actualGradients, _, _ :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights let refNet = Reference.lstmToReference lstmWeights
@ -79,7 +81,9 @@ prop_lstm_reference_backwards_input =
input :: S.R 3 <- forAll randomVector input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2)) let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
= runRecurrentForwards net (S1D cell) (S1D input)
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of case actualBacks of
(_, _, S1D actualGradients :: S ('D1 3)) -> (_, _, S1D actualGradients :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights let refNet = Reference.lstmToReference lstmWeights
@ -93,7 +97,9 @@ prop_lstm_reference_backwards_cell =
input :: S.R 3 <- forAll randomVector input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2)) let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
= runRecurrentForwards net (S1D cell) (S1D input)
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of case actualBacks of
(_, S1D actualGradients, _ :: S ('D1 3)) -> (_, S1D actualGradients, _ :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights let refNet = Reference.lstmToReference lstmWeights