Remove use of RecordWildCards

This commit is contained in:
Erik de Castro Lopo 2020-04-11 17:18:23 +10:00
parent f19da0f74d
commit 5206c95c42
8 changed files with 91 additions and 104 deletions

View File

@ -1,6 +1,5 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
@ -83,34 +82,34 @@ loadShakespeare path = do
return (V.fromList hot, m, cs)
trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
trainSlice !rate !net !recIns input offset size =
trainSlice !lrate !net !recIns input offset size =
let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
in case reverse e of
(o : l : xs) ->
let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
in trainRecurrent rate net recIns examples
in trainRecurrent lrate net recIns examples
_ -> error "Not enough input"
where
x = fromMaybe (error "Hot variable didn't fit.")
runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
runShakespeare ShakespeareOpts {..} = do
(shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
runShakespeare opts = do
(shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare $ trainingFile opts
(net0, i0) <- lift $
case loadPath of
case loadPath opts of
Just loadFile -> netLoad loadFile
Nothing -> (,0) <$> randomNet
(trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do
xs <- take (iterations `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
results <- take 1000 <$> generateParagraph trained bestInput temperature oneHotMap oneHotDictionary ( S1D $ konst 0)
xs <- take (iterations opts `div` 10) <$> getRandomRs (0, length shakespeare - size - 1)
let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice (rate opts) n i shakespeare offset size) (net, io) xs
results <- take 1000 <$> generateParagraph trained bestInput (temperature opts) oneHotMap oneHotDictionary ( S1D $ konst 0)
putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
putStrLn (unAnnotateCapitals results)
return (trained, bestInput)
) (net0, i0) $ replicate 10 sequenceSize
) (net0, i0) $ replicate 10 (sequenceSize opts)
case savePath of
case savePath opts of
Just saveFile -> lift . B.writeFile saveFile $ runPut (put trained >> put bestInput)
Nothing -> return ()
@ -122,12 +121,12 @@ generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes
-> Vector a
-> S ('D1 n)
-> IO [a]
generateParagraph n s temperature hotmap hotdict =
generateParagraph n s temp hotmap hotdict =
go s
where
go x y =
do let (_, ns, o) = runRecurrent n x y
un <- sample temperature hotdict o
un <- sample temp hotdict o
Just re <- return $ makeHot hotmap un
rest <- unsafeInterleaveIO $ go ns re
return (un : rest)

View File

@ -1,7 +1,6 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
@ -139,8 +138,8 @@ instance ( KnownNat channels
, KnownNat (kernelRows * kernelColumns * channels)
) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
type Gradient (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Convolution' channels filters kernelRows kernelColumns strideRows strideColumns)
runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
runUpdate lp (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
in Convolution newKernel newMomentum
createRandom = randomConvolution

View File

@ -1,7 +1,6 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
@ -138,8 +137,8 @@ instance ( KnownNat channels
, KnownNat (kernelRows * kernelColumns * filters)
) => UpdateLayer (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) where
type Gradient (Deconvolution channels filters kernelRows kernelColumns strideRows strideColumns) = (Deconvolution' channels filters kernelRows kernelColumns strideRows strideColumns)
runUpdate LearningParameters {..} (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
let (newKernel, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
runUpdate lp (Deconvolution oldKernel oldMomentum) (Deconvolution' kernelGradient) =
let (newKernel, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldKernel kernelGradient oldMomentum
in Deconvolution newKernel newMomentum
createRandom = randomDeconvolution

View File

@ -1,5 +1,4 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@ -38,9 +37,9 @@ instance Show (FullyConnected i o) where
instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
type Gradient (FullyConnected i o) = (FullyConnected' i o)
runUpdate LearningParameters {..} (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
let (newBias, newBiasMomentum) = descendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum
(newActivations, newMomentum) = descendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum
runUpdate lp (FullyConnected (FullyConnected' oldBias oldActivations) (FullyConnected' oldBiasMomentum oldMomentum)) (FullyConnected' biasGradient activationGradient) =
let (newBias, newBiasMomentum) = descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldBias biasGradient oldBiasMomentum
(newActivations, newMomentum) = descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) oldActivations activationGradient oldMomentum
in FullyConnected (FullyConnected' newBias newActivations) (FullyConnected' newBiasMomentum newMomentum)
createRandom = randomFullyConnected

View File

@ -6,7 +6,6 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Grenade.Recurrent.Core.Runner (
runRecurrentExamples
@ -101,11 +100,11 @@ updateRecInputs :: Fractional (RecurrentInputs sublayers)
-> RecurrentInputs sublayers
-> RecurrentInputs sublayers
updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
= () :~~+> updateRecInputs l xs ys
updateRecInputs lp (() :~~+> xs) (() :~~+> ys)
= () :~~+> updateRecInputs lp xs ys
updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
= (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
updateRecInputs lp (x :~@+> xs) (y :~@+> ys)
= (realToFrac (1 - learningRate lp * learningRegulariser lp) * x - realToFrac (learningRate lp) * y) :~@+> updateRecInputs lp xs ys
updateRecInputs _ RINil RINil
= RINil

View File

@ -1,7 +1,6 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@ -60,11 +59,11 @@ instance Show (BasicRecurrent i o) where
instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)
runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
runUpdate lp (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
let newBiasMomentum = konst (learningMomentum lp) * oldBiasMomentum - konst (learningRate lp) * biasGradient
newBias = oldBias + newBiasMomentum
newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
regulariser = konst (learningRegulariser * learningRate) * oldActivations
newMomentum = konst (learningMomentum lp) * oldMomentum - konst (learningRate lp) * activationGradient
regulariser = konst (learningRegulariser lp * learningRate lp) * oldActivations
newActivations = oldActivations + newMomentum - regulariser
in BasicRecurrent newBias newBiasMomentum newActivations newMomentum

View File

@ -3,7 +3,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
@ -75,7 +74,7 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
-- Run the update function for each group matrix/vector of weights, momentums and gradients.
-- Hmm, maybe the function should be used instead of passing in the learning parameters.
runUpdate LearningParameters {..} (LSTM w m) g =
runUpdate lp (LSTM w m) g =
let (wf, wf') = u lstmWf w m g
(uf, uf') = u lstmUf w m g
(bf, bf') = v lstmBf w m g
@ -92,11 +91,11 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
-- Utility function for updating with the momentum, gradients, and weights.
u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
u e (e -> weights) (e -> momentum) (e -> gradient) =
descendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
descendMatrix (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
v e (e -> weights) (e -> momentum) (e -> gradient) =
descendVector learningRate learningMomentum learningRegulariser weights gradient momentum
descendVector (learningRate lp) (learningMomentum lp) (learningRegulariser lp) weights gradient momentum
-- There's a lot of updates here, so to try and minimise the number of data copies
-- we'll create a mutable bucket for each.
@ -137,18 +136,18 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
type RecTape (LSTM i o) ('D1 i) ('D1 o) = (R o, R i, R o, R o, R o, R o, R o, R o, R o, R o, R o)
-- Forward propagation for the LSTM layer.
-- The size of the cell state is also the size of the output.
runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
runRecurrentForwards (LSTM lw _) (S1D cell) (S1D input) =
let -- Forget state vector
f_s = lstmBf + lstmWf #> input + lstmUf #> cell
f_s = lstmBf lw + lstmWf lw #> input + lstmUf lw #> cell
f_t = sigmoid f_s
-- Input state vector
i_s = lstmBi + lstmWi #> input + lstmUi #> cell
i_s = lstmBi lw + lstmWi lw #> input + lstmUi lw #> cell
i_t = sigmoid i_s
-- Output state vector
o_s = lstmBo + lstmWo #> input + lstmUo #> cell
o_s = lstmBo lw + lstmWo lw #> input + lstmUo lw #> cell
o_t = sigmoid o_s
-- Cell input state vector
c_s = lstmBc + lstmWc #> input
c_s = lstmBc lw + lstmWc lw #> input
c_x = tanh c_s
-- Cell state
c_t = f_t * cell + i_t * c_x
@ -162,7 +161,7 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
--
-- There's a test version using the AD library without hmatrix in the test
-- suite. These should match always.
runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
runRecurrentBackwards (LSTM lw _) (cell, input, f_s, f_t, i_s, i_t, o_s, o_t, c_s, c_x, c_t) (S1D cellGrad) (S1D h_t') =
let -- Reverse Mode AD Derivitives
c_t' = h_t' * o_t + cellGrad
@ -179,8 +178,8 @@ instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) w
c_s' = tanh' c_s * c_x'
-- The derivatives to pass sideways (recurrent) and downwards
cell' = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'
cell' = tr (lstmUf lw) #> f_s' + tr (lstmUo lw) #> o_s' + tr (lstmUi lw) #> i_s' + c_t' * f_t
input' = tr (lstmWf lw) #> f_s' + tr (lstmWo lw) #> o_s' + tr (lstmWi lw) #> i_s' + tr (lstmWc lw) #> c_s'
-- Calculate the gradient Matricies for the input
lstmWf' = f_s' `outer` input
@ -239,44 +238,34 @@ tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
put (LSTM LSTMWeights {..} _) = do
u lstmWf
u lstmUf
v lstmBf
u lstmWi
u lstmUi
v lstmBi
u lstmWo
u lstmUo
v lstmBo
u lstmWc
v lstmBc
where
u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
u = putListOf put . LA.toList . LA.flatten . extract
v :: forall a. (KnownNat a) => Putter (R a)
v = putListOf put . LA.toList . extract
put (LSTM lw _) = do
u (lstmWf lw)
u (lstmUf lw)
v (lstmBf lw)
u (lstmWi lw)
u (lstmUi lw)
v (lstmBi lw)
u (lstmWo lw)
u (lstmUo lw)
v (lstmBo lw)
u (lstmWc lw)
v (lstmBc lw)
where
u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
u = putListOf put . LA.toList . LA.flatten . extract
v :: forall a. (KnownNat a) => Putter (R a)
v = putListOf put . LA.toList . extract
get = do
lstmWf <- u
lstmUf <- u
lstmBf <- v
lstmWi <- u
lstmUi <- u
lstmBi <- v
lstmWo <- u
lstmUo <- u
lstmBo <- v
lstmWc <- u
lstmBc <- v
return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
where
u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
u = let f = fromIntegral $ natVal (Proxy :: Proxy a)
in maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get
v :: forall a. (KnownNat a) => Get (R a)
v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get
w <- LSTMWeights <$> u <*> u <*> v <*> u <*> u <*> v <*> u <*> u <*> v <*> u <*> v
return $ LSTM w (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
where
u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
u = let f = fromIntegral $ natVal (Proxy :: Proxy a)
in maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get
v :: forall a. (KnownNat a) => Get (R a)
v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get
w0 = konst 0
u0 = konst 0
v0 = konst 0
w0 = konst 0
u0 = konst 0
v0 = konst 0

View File

@ -6,7 +6,6 @@
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
@ -15,9 +14,13 @@ module Test.Grenade.Recurrent.Layers.LSTM.Reference where
import Data.Reflection
import Numeric.AD.Mode.Reverse
import Numeric.AD.Internal.Reverse ( Tape )
import Numeric.AD.Internal.Reverse (Tape)
import GHC.TypeLits (KnownNat)
import Grenade.Recurrent.Layers.LSTM (LSTMWeights (..))
import qualified Grenade.Recurrent.Layers.LSTM as LSTM
import qualified Numeric.LinearAlgebra.Static as S
import qualified Numeric.LinearAlgebra as H
@ -54,31 +57,32 @@ data RefLSTM a = RefLSTM
, refLstmBc :: Vector a -- Bias Cell (b_c)
} deriving (Functor, Foldable, Traversable, Eq, Show)
lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double
lstmToReference LSTM.LSTMWeights {..} =
let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget (W_f)
refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f)
refLstmBf = Vector . H.toList . S.extract $ lstmBf -- Bias Forget (b_f)
refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input (W_i)
refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input (U_i)
refLstmBi = Vector . H.toList . S.extract $ lstmBi -- Bias Input (b_i)
refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output (W_o)
refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o)
refLstmBo = Vector . H.toList . S.extract $ lstmBo -- Bias Output (b_o)
refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell (W_c)
refLstmBc = Vector . H.toList . S.extract $ lstmBc -- Bias Cell (b_c)
in RefLSTM {..}
lstmToReference :: (KnownNat a, KnownNat b) => LSTM.LSTMWeights a b -> RefLSTM Double
lstmToReference lw =
RefLSTM
{ refLstmWf = Matrix . H.toLists . S.extract $ lstmWf lw -- Weight Forget (W_f)
, refLstmUf = Matrix . H.toLists . S.extract $ lstmUf lw -- Cell State Forget (U_f)
, refLstmBf = Vector . H.toList . S.extract $ lstmBf lw -- Bias Forget (b_f)
, refLstmWi = Matrix . H.toLists . S.extract $ lstmWi lw -- Weight Input (W_i)
, refLstmUi = Matrix . H.toLists . S.extract $ lstmUi lw -- Cell State Input (U_i)
, refLstmBi = Vector . H.toList . S.extract $ lstmBi lw -- Bias Input (b_i)
, refLstmWo = Matrix . H.toLists . S.extract $ lstmWo lw -- Weight Output (W_o)
, refLstmUo = Matrix . H.toLists . S.extract $ lstmUo lw -- Cell State Output (U_o)
, refLstmBo = Vector . H.toList . S.extract $ lstmBo lw -- Bias Output (b_o)
, refLstmWc = Matrix . H.toLists . S.extract $ lstmWc lw -- Weight Cell (W_c)
, refLstmBc = Vector . H.toList . S.extract $ lstmBc lw -- Bias Cell (b_c)
}
runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a)
runLSTM RefLSTM {..} cell input =
runLSTM rl cell input =
let -- Forget state vector
f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell
f_t = sigmoid $ refLstmBf rl #+ refLstmWf rl #> input #+ refLstmUf rl #> cell
-- Input state vector
i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell
i_t = sigmoid $ refLstmBi rl #+ refLstmWi rl #> input #+ refLstmUi rl #> cell
-- Output state vector
o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell
o_t = sigmoid $ refLstmBo rl #+ refLstmWo rl #> input #+ refLstmUo rl #> cell
-- Cell input state vector
c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input
c_x = fmap tanh $ refLstmBc rl #+ refLstmWc rl #> input
-- Cell state
c_t = f_t #* cell #+ i_t #* c_x
-- Output (it's sometimes recommended to use tanh c_t)