mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-21 21:59:30 +03:00
Optimise Wengert tape for LSTM
This commit is contained in:
parent
4332d62c71
commit
a5881518bf
@ -5,61 +5,22 @@ import Criterion.Main
|
||||
|
||||
import Grenade
|
||||
import Grenade.Recurrent
|
||||
import Grenade.Layers.Internal.Update
|
||||
|
||||
import qualified Numeric.LinearAlgebra.Static as H
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
layer60 :: LSTM 40 60 <- createRandom
|
||||
layer512 :: LSTM 40 512 <- createRandom
|
||||
input40 :: S ('D1 40) <- randomOfShape
|
||||
rec60 :: S ('D1 60) <- randomOfShape
|
||||
rec512 :: S ('D1 512) <- randomOfShape
|
||||
lstm :: RecNet <- randomRecurrent
|
||||
|
||||
let upIn60 :: H.R 3600 = H.randomVector 1 H.Uniform * 2 - 1
|
||||
let upIn512 :: H.R 262144 = H.randomVector 1 H.Uniform * 2 - 1
|
||||
|
||||
defaultMain [
|
||||
bgroup "lstm" [ bench "forwards-60" $ nf (nfT3 . uncurry (testRun60 layer60)) (rec60, input40)
|
||||
, bench "forwards-512" $ nf (nfT3 . uncurry (testRun512 layer512)) (rec512, input40)
|
||||
, bench "backwards-60" $ nf (nfT3 . uncurry4 (testRun60' layer60)) (rec60, input40, rec60, rec60)
|
||||
, bench "backwards-512" $ nf (nfT3 . uncurry4 (testRun512' layer512)) (rec512, input40, rec512, rec512)
|
||||
]
|
||||
, bgroup "update" [ bench "matrix-60x60" $ nf (uncurry3 (descendVector 1 1 1)) (upIn60, upIn60, upIn60)
|
||||
, bench "matrix-512x512" $ nf (uncurry3 (descendVector 1 1 1)) (upIn512, upIn512, upIn512)
|
||||
]
|
||||
, bgroup "train" [ bench "one-time-step" $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
|
||||
bgroup "train" [ bench "one-time-step" $ whnf (nfT2 . trainRecurrent lp lstm 0) [(input40, Just input40)]
|
||||
, bench "ten-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 10 (input40, Just input40)
|
||||
, bench "fifty-time-steps" $ whnf (nfT2 . trainRecurrent lp lstm 0) $ replicate 50 (input40, Just input40)
|
||||
]
|
||||
]
|
||||
|
||||
testRun60 :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> ((S ('D1 60), S ('D1 40)), S ('D1 60), S ('D1 60))
|
||||
testRun60 = runRecurrentForwards
|
||||
|
||||
testRun60' :: LSTM 40 60 -> S ('D1 60) -> S ('D1 40) -> S ('D1 60) -> S ('D1 60) -> (Gradient (LSTM 40 60), S ('D1 60), S ('D1 40))
|
||||
testRun60' = curry . runRecurrentBackwards
|
||||
|
||||
testRun512 :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> ((S ('D1 512), S ('D1 40)), S ('D1 512), S ('D1 512))
|
||||
testRun512 = runRecurrentForwards
|
||||
|
||||
testRun512' :: LSTM 40 512 -> S ('D1 512) -> S ('D1 40) -> S ('D1 512) -> S ('D1 512) -> (Gradient (LSTM 40 512), S ('D1 512), S ('D1 40))
|
||||
testRun512' = curry . runRecurrentBackwards
|
||||
|
||||
uncurry4 :: (t -> t1 -> t2 -> t3 -> t4) -> (t, t1, t2, t3) -> t4
|
||||
uncurry4 f (a,b,c,d) = f a b c d
|
||||
|
||||
uncurry3 :: (t -> t1 -> t2 -> t3) -> (t, t1, t2) -> t3
|
||||
uncurry3 f (a,b,c) = f a b c
|
||||
|
||||
nfT2 :: (a, b) -> (a, b)
|
||||
nfT2 (!a, !b) = (a, b)
|
||||
|
||||
nfT3 :: (a, b, c) -> (b, c)
|
||||
nfT3 (!_, !b, !c) = (b, c)
|
||||
|
||||
|
||||
type R = Recurrent
|
||||
type RecNet = RecurrentNetwork '[ R (LSTM 40 512), R (LSTM 512 40) ]
|
||||
|
@ -36,12 +36,12 @@ instance Serialize Elu where
|
||||
get = return Elu
|
||||
|
||||
instance ( KnownNat i) => Layer Elu ('D1 i) ('D1 i) where
|
||||
type Tape Elu ('D1 i) ('D1 i) = S ('D1 i)
|
||||
type Tape Elu ('D1 i) ('D1 i) = LAS.R i
|
||||
|
||||
runForwards _ (S1D y) = (S1D y, S1D (elu y))
|
||||
runForwards _ (S1D y) = (y, S1D (elu y))
|
||||
where
|
||||
elu = LAS.dvmap (\a -> if a <= 0 then exp a - 1 else a)
|
||||
runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (elu' y * dEdy))
|
||||
runBackwards _ y (S1D dEdy) = ((), S1D (elu' y * dEdy))
|
||||
where
|
||||
elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1)
|
||||
|
||||
|
@ -46,12 +46,12 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
|
||||
createRandom = randomFullyConnected
|
||||
|
||||
instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
|
||||
type Tape (FullyConnected i o) ('D1 i) ('D1 o) = S ('D1 i)
|
||||
type Tape (FullyConnected i o) ('D1 i) ('D1 o) = R i
|
||||
-- Do a matrix vector multiplication and return the result.
|
||||
runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (S1D v, S1D (wB + wN #> v))
|
||||
runForwards (FullyConnected (FullyConnected' wB wN) _) (S1D v) = (v, S1D (wB + wN #> v))
|
||||
|
||||
-- Run a backpropogation step for a full connected layer.
|
||||
runBackwards (FullyConnected (FullyConnected' _ wN) _) (S1D x) (S1D dEdy) =
|
||||
runBackwards (FullyConnected (FullyConnected' _ wN) _) x (S1D dEdy) =
|
||||
let wB' = dEdy
|
||||
mm' = dEdy `outer` x
|
||||
-- calcluate derivatives for next step
|
||||
|
@ -34,18 +34,24 @@ instance UpdateLayer Logit where
|
||||
createRandom = return Logit
|
||||
|
||||
instance (a ~ b, SingI a) => Layer Logit a b where
|
||||
-- Wengert tape optimisation:
|
||||
--
|
||||
-- Derivative of the sigmoid function is
|
||||
-- d σ(x) / dx = σ(x) • (1 - σ(x))
|
||||
-- but we have already calculated σ(x) in
|
||||
-- the forward pass, so just store that
|
||||
-- and use it in the backwards pass.
|
||||
type Tape Logit a b = S a
|
||||
runForwards _ a = (a, logistic a)
|
||||
runBackwards _ a g = ((), logistic' a * g)
|
||||
runForwards _ a =
|
||||
let l = sigmoid a
|
||||
in (l, l)
|
||||
runBackwards _ l g =
|
||||
let sigmoid' = l * (1 - l)
|
||||
in ((), sigmoid' * g)
|
||||
|
||||
instance Serialize Logit where
|
||||
put _ = return ()
|
||||
get = return Logit
|
||||
|
||||
logistic :: Floating a => a -> a
|
||||
logistic x = 1 / (1 + exp (-x))
|
||||
|
||||
logistic' :: Floating a => a -> a
|
||||
logistic' x = logix * (1 - logix)
|
||||
where
|
||||
logix = logistic x
|
||||
sigmoid :: Floating a => a -> a
|
||||
sigmoid x = 1 / (1 + exp (-x))
|
||||
|
@ -127,37 +127,12 @@ instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
|
||||
|
||||
instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
|
||||
|
||||
type RecTape (LSTM i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i))
|
||||
-- The tape stores essentially every variable we calculate,
|
||||
-- so we don't have to run any forwards component again.
|
||||
type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
|
||||
-- Forward propagation for the LSTM layer.
|
||||
-- The size of the cell state is also the size of the output.
|
||||
runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
|
||||
let -- Forget state vector
|
||||
f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
|
||||
-- Input state vector
|
||||
i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
|
||||
-- Output state vector
|
||||
o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
|
||||
-- Cell input state vector
|
||||
c_x = tanh $ lstmBc + lstmWc #> input
|
||||
-- Cell state
|
||||
c_t = f_t * cell + i_t * c_x
|
||||
-- Output (it's sometimes recommended to use tanh c_t)
|
||||
h_t = o_t * c_t
|
||||
in ((S1D cell, S1D input), S1D c_t, S1D h_t)
|
||||
|
||||
-- Run a backpropogation step for an LSTM layer.
|
||||
-- We're doing all the derivatives by hand here, so one should
|
||||
-- be extra careful when changing this.
|
||||
--
|
||||
-- There's a test version using the AD library without hmatrix in the test
|
||||
-- suite. These should match always.
|
||||
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell, S1D input) (S1D cellGrad) (S1D h_t') =
|
||||
-- We're not keeping the Wengert tape during the forward pass,
|
||||
-- so we're duplicating some work here.
|
||||
--
|
||||
-- If I was being generous, I'd call it checkpointing.
|
||||
--
|
||||
-- Maybe think about better ways to store some intermediate states.
|
||||
let -- Forget state vector
|
||||
f_s = lstmBf + lstmWf #> input + lstmUf #> cell
|
||||
f_t = sigmoid f_s
|
||||
@ -172,8 +147,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
|
||||
c_x = tanh c_s
|
||||
-- Cell state
|
||||
c_t = f_t * cell + i_t * c_x
|
||||
-- Output (it's sometimes recommended to use tanh c_t)
|
||||
h_t = o_t * c_t
|
||||
in ((cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t), S1D c_t, S1D h_t)
|
||||
|
||||
-- Reverse Mode AD Derivitives
|
||||
-- Run a backpropogation step for an LSTM layer.
|
||||
-- We're doing all the derivatives by hand here, so one should
|
||||
-- be extra careful when changing this.
|
||||
--
|
||||
-- There's a test version using the AD library without hmatrix in the test
|
||||
-- suite. These should match always.
|
||||
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
|
||||
let -- Reverse Mode AD Derivitives
|
||||
c_t' = h_t' * o_t + cellGrad
|
||||
|
||||
f_t' = c_t' * cell
|
||||
@ -235,7 +220,8 @@ randomLSTM = do
|
||||
|
||||
-- | Maths
|
||||
--
|
||||
-- TODO: move to not here
|
||||
-- TODO: Move to not here
|
||||
-- Optimise backwards derivative
|
||||
sigmoid :: Floating a => a -> a
|
||||
sigmoid x = 1 / (1 + exp (-x))
|
||||
|
||||
|
@ -65,7 +65,9 @@ prop_lstm_reference_backwards =
|
||||
input :: S.R 3 <- forAll randomVector
|
||||
cell :: S.R 2 <- forAll randomVector
|
||||
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
|
||||
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
|
||||
let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
|
||||
= runRecurrentForwards net (S1D cell) (S1D input)
|
||||
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
|
||||
case actualBacks of
|
||||
(actualGradients, _, _ :: S ('D1 3)) ->
|
||||
let refNet = Reference.lstmToReference lstmWeights
|
||||
@ -79,7 +81,9 @@ prop_lstm_reference_backwards_input =
|
||||
input :: S.R 3 <- forAll randomVector
|
||||
cell :: S.R 2 <- forAll randomVector
|
||||
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
|
||||
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
|
||||
let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
|
||||
= runRecurrentForwards net (S1D cell) (S1D input)
|
||||
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
|
||||
case actualBacks of
|
||||
(_, _, S1D actualGradients :: S ('D1 3)) ->
|
||||
let refNet = Reference.lstmToReference lstmWeights
|
||||
@ -93,7 +97,9 @@ prop_lstm_reference_backwards_cell =
|
||||
input :: S.R 3 <- forAll randomVector
|
||||
cell :: S.R 2 <- forAll randomVector
|
||||
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
|
||||
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
|
||||
let (tape, _ :: S ('D1 2), _ :: S ('D1 2))
|
||||
= runRecurrentForwards net (S1D cell) (S1D input)
|
||||
actualBacks = runRecurrentBackwards net tape (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
|
||||
case actualBacks of
|
||||
(_, S1D actualGradients, _ :: S ('D1 3)) ->
|
||||
let refNet = Reference.lstmToReference lstmWeights
|
||||
|
Loading…
Reference in New Issue
Block a user