Update LSTM

Huw Campbell 2017-01-19 18:55:13 +11:00
parent 4dc408f39d
commit bcd1856988
11 changed files with 200 additions and 65 deletions

View File

@ -23,9 +23,18 @@ randomMnist :: MonadRandom m => m MNIST
randomMnist = randomNetwork
```
And that's it. Because the types are rich, there's no specific term level code
required; although it is of course possible and easy to construct one explicitly
oneself.
And that's it. Because the types are so rich, there's no specific term level code
required to construct this network; although it is of course possible and
easy to construct and deconstruct the networks and layers explicitly oneself.
If recurrent neural networks are more your style, you can try defining something
["unreasonably effective"](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)
with
```haskell
type Shakespeare = RecurrentNetwork '[ R (LSTM 40 80), R (LSTM 80 40), F (FullyConnected 40 40), F Logit]
'[ 'D1 40, 'D1 80, 'D1 40, 'D1 40, 'D1 40 ]
```
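As with the MNIST network above, the type alone is enough to build a working initial network; a minimal sketch, taken from the example program in this commit (`Shakespearian` is the matching alias for the recurrent inputs):
```haskell
type Shakespearian = RecurrentInputs '[ R (LSTM 40 80), R (LSTM 80 40), F (FullyConnected 40 40), F Logit]

randomNet :: MonadRandom m => m (Shakespeare, Shakespearian)
randomNet = randomRecurrent
```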
The network can be thought of as a heterogeneous list of layers, where its type
includes not only the layers of the network, but also the shapes of data that
@ -107,13 +116,8 @@ Being purely functional, it should also be easy to run batches in parallel, whic
would be appropriate for larger networks; my current examples, however, are single
threaded.
Training 15 generations over Kaggle's 41000 sample MNIST training set on a single
core took around 12 minutes, achieving 1.5% error rate on a 1000 sample holdout set.
Contributing
------------

View File

@ -135,6 +135,8 @@ executable shakespeare
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.12.*
, text == 1.2.*

View File

@ -11,7 +11,7 @@ import Control.Monad.Random
import Control.Monad.Trans.Except
import Data.Char ( isUpper, toUpper, toLower )
import Data.List ( unfoldr, foldl' )
import Data.List ( foldl' )
import Data.Maybe ( fromMaybe )
import qualified Data.Vector as V
@ -20,6 +20,8 @@ import Data.Vector ( Vector )
import qualified Data.Map as M
import Data.Proxy ( Proxy (..) )
import qualified Data.ByteString as B
import Data.Serialize
import Data.Singletons.Prelude
import GHC.TypeLits
@ -32,29 +34,31 @@ import Grenade
import Grenade.Recurrent
import Grenade.Utils.OneHot
import System.IO.Unsafe ( unsafeInterleaveIO )
-- The definition of our natural language recurrent network.
-- This network is able to learn and generate simple words in
-- about an hour.
--
-- This is a first class recurrent net, although it's similar to
-- an unrolled graph.
-- This is a first class recurrent net.
--
-- The F and R types are tagging types to ensure that the runner and
-- creation function know how to treat the layers.
--
-- As an example, here's a short sequence generated.
--
-- > the see and and the sir, and and the make and the make and go the make and go the make and the
--
-- > KING RICHARD III:
-- > And as the heaven her his words, we the son, I show sand stape but the lament to shall were the sons with a strend
type F = FeedForward
type R = Recurrent
-- The definition of our network
type Shakespeare = RecurrentNetwork '[ R (LSTM 40 50), R (LSTM 50 40), F (FullyConnected 40 40), F Logit]
'[ 'D1 40, 'D1 50, 'D1 40, 'D1 40, 'D1 40 ]
type Shakespeare = RecurrentNetwork '[ R (LSTM 40 80), R (LSTM 80 40), F (FullyConnected 40 40), F Logit]
'[ 'D1 40, 'D1 80, 'D1 40, 'D1 40, 'D1 40 ]
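-- The shape list gives the data passed between the layers: a 40 wide
-- (one-hot character) input, an 80 wide hidden activation between the
-- LSTMs, and 40 wide vectors through to the final output distribution.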
-- The definition of the "sideways" input, with which the network is fed recurrently.
type Shakespearian = RecurrentInputs '[ R (LSTM 40 50), R (LSTM 50 40), F (FullyConnected 40 40), F Logit]
type Shakespearian = RecurrentInputs '[ R (LSTM 40 80), R (LSTM 80 40), F (FullyConnected 40 40), F Logit]
randomNet :: MonadRandom m => m (Shakespeare, Shakespearian)
randomNet = randomRecurrent
@ -82,15 +86,23 @@ trainSlice !rate !net !recIns input offset size =
runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
runShakespeare ShakespeareOpts {..} = do
(shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
(net0, i0) <- lift randomNet
lift $ foldM_ (\(!net, !io) size -> do
(net0, i0) <- lift $
case loadPath of
Just loadFile -> netLoad loadFile
Nothing -> randomNet
(trained, bestInput) <- lift $ foldM (\(!net, !io) size -> do
xs <- take (iterations `div` 15) <$> getRandomRs (0, length shakespeare - size - 1)
let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
let results = take 100 $ generateParagraph trained bestInput oneHotMap oneHotDictionary ( S1D $ konst 0)
results <- take 1000 <$> generateParagraph trained bestInput oneHotMap oneHotDictionary ( S1D $ konst 0)
putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
putStrLn (unAnnotateCapitals results)
return (trained, bestInput)
) (net0, i0) [10,10,15,15,20,20,25,25,30,30,35,35,40,40,50 :: Int]
) (net0, i0) [50,50,50,50,50,50,50,50,50,50,50,50,50,50,50 :: Int]
case savePath of
Just saveFile -> lift . B.writeFile saveFile $ runPut (put (trained, bestInput))
Nothing -> return ()
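-- Lazily generate a stream of characters from the trained network: run one
-- step of the network, sample a character from its output with `sample`,
-- then feed that character back in as the next input. The interleaved IO
-- below keeps the (conceptually infinite) stream lazy.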
generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes ~ 'D1 n, KnownNat n, Ord a)
=> RecurrentNetwork layers shapes
@ -98,20 +110,23 @@ generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes
-> M.Map a Int
-> Vector a
-> S ('D1 n)
-> [a]
generateParagraph n s hotmap hotdict i =
unfoldr go (s, i)
-> IO [a]
generateParagraph n s hotmap hotdict =
go s
where
go (x, y) =
go x y =
do let (ns, o) = runRecurrent n x y
un <- unHot hotdict o
re <- makeHot hotmap un
Just (un, (ns, re))
un <- sample 0.4 hotdict o
Just re <- return $ makeHot hotmap un
rest <- unsafeInterleaveIO $ go ns re
return (un : rest)
data ShakespeareOpts = ShakespeareOpts {
trainingFile :: FilePath
, iterations :: Int
, rate :: LearningParameters
, loadPath :: Maybe FilePath
, savePath :: Maybe FilePath
}
shakespeare' :: Parser ShakespeareOpts
@ -122,6 +137,8 @@ shakespeare' = ShakespeareOpts <$> argument str (metavar "TRAIN")
<*> option auto (long "momentum" <> value 0.95)
<*> option auto (long "l2" <> value 0.000001)
)
<*> optional (strOption (long "load"))
<*> optional (strOption (long "save"))
main :: IO ()
main = do
@ -132,6 +149,11 @@ main = do
Left err -> putStrLn err
netLoad :: FilePath -> IO (Shakespeare, Shakespearian)
netLoad modelPath = do
modelData <- B.readFile modelPath
either fail return $ runGet get modelData
-- Replace capitals with an annotation and the lower case letter
-- http://fastml.com/one-weird-trick-for-training-char-rnns/
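-- For example (assuming '^' as the annotation character purely for
-- illustration; the real marker is defined elsewhere in this file):
--   annotateCapitals   "The King" == "^the ^king"
--   unAnnotateCapitals "^the ^king" == "The King"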
annotateCapitals :: String -> String

View File

@ -65,30 +65,36 @@ data S (n :: Shape) where
deriving instance Show (S n)
-- Singletons
-- These could probably be derived with template haskell, but this seems
-- clear and makes adding the KnownNat constraints simple.
data instance Sing (n :: Shape) where
D1Sing :: KnownNat a => Sing ('D1 a)
D2Sing :: (KnownNat a, KnownNat b) => Sing ('D2 a b)
D3Sing :: (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => Sing ('D3 a b c)
instance KnownNat a => SingI ('D1 a) where
sing = D1Sing
instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
sing = D2Sing
instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
sing = D3Sing
instance SingI x => Num (S x) where
(+) = n2 (+)
(-) = n2 (-)
(*) = n2 (*)
abs = n1 abs
signum = n1 signum
fromInteger x = case (sing :: Sing x) of
D1Sing -> S1D (konst $ fromInteger x)
D2Sing -> S2D (konst $ fromInteger x)
D3Sing -> S3D (konst $ fromInteger x)
fromInteger x = nk (fromInteger x)
instance SingI x => Fractional (S x) where
(/) = n2 (/)
recip = n1 recip
fromRational x = case (sing :: Sing x) of
D1Sing -> S1D (konst $ fromRational x)
D2Sing -> S2D (konst $ fromRational x)
D3Sing -> S3D (konst $ fromRational x)
fromRational x = nk (fromRational x)
instance SingI x => Floating (S x) where
pi = case (sing :: Sing x) of
D1Sing -> S1D (konst pi)
D2Sing -> S2D (konst pi)
D3Sing -> S3D (konst pi)
pi = nk pi
exp = n1 exp
log = n1 log
sqrt = n1 sqrt
@ -107,21 +113,6 @@ instance SingI x => Floating (S x) where
acosh = n1 acosh
atanh = n1 atanh
-- Singletons
-- These could probably be derived with template haskell, but this seems
-- clear and makes adding the KnownNat constraints simple.
data instance Sing (n :: Shape) where
D1Sing :: KnownNat a => Sing ('D1 a)
D2Sing :: (KnownNat a, KnownNat b) => Sing ('D2 a b)
D3Sing :: (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => Sing ('D3 a b c)
instance KnownNat a => SingI ('D1 a) where
sing = D1Sing
instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
sing = D2Sing
instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
sing = D3Sing
--
-- I haven't made shapes strict, as sometimes they're not needed
-- (the last input gradient back for instance)
@ -170,3 +161,10 @@ n2 f (S1D x) (S1D y) = S1D (f x y)
n2 f (S2D x) (S2D y) = S2D (f x y)
n2 f (S3D x) (S3D y) = S3D (f x y)
n2 _ _ _ = error "Impossible to have different constructors for the same shaped network"
-- Helper function for creating the number instances
nk :: forall x. SingI x => Double -> S x
nk x = case (sing :: Sing x) of
D1Sing -> S1D (konst x)
D2Sing -> S2D (konst x)
D3Sing -> S3D (konst x)
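-- With the instances above this means, for example, that the literal
-- (3 :: S ('D2 2 3)) denotes a 2x3 matrix in which every element is 3.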

View File

@ -19,10 +19,15 @@ module Grenade.Recurrent.Core.Network (
import Control.Monad.Random ( MonadRandom )
import Data.Singletons ( SingI )
import Data.Serialize
import qualified Data.Vector.Storable as V
import Grenade.Core.Shape
import Grenade.Core.Network
import qualified Numeric.LinearAlgebra as LA
import qualified Numeric.LinearAlgebra.Static as LAS
-- | Witness type to indicate we're building up with a normal feed
-- forward layer.
@ -96,3 +101,38 @@ instance (SingI i, RecurrentLayer x i o, CreatableRecurrent xs (o ': r ': rs)) =
(rest, resti) <- randomRecurrent
return (thisLayer :~@> rest, thisShape :~@+> resti)
-- | Add very simple serialisation to the recurrent network
instance (SingI i, SingI o, Layer x i o, Serialize x) => Serialize (RecurrentNetwork '[FeedForward x] '[i, o]) where
put (OR x) = put x
put _ = error "impossible"
get = OR <$> get
instance (SingI i, Layer x i o, Serialize x, Serialize (RecurrentNetwork xs (o ': r ': rs))) => Serialize (RecurrentNetwork (FeedForward x ': xs) (i ': o ': r ': rs)) where
put (x :~~> r) = put x >> put r
get = (:~~>) <$> get <*> get
instance (SingI i, RecurrentLayer x i o, Serialize x, Serialize (RecurrentNetwork xs (o ': r ': rs))) => Serialize (RecurrentNetwork (Recurrent x ': xs) (i ': o ': r ': rs)) where
put (x :~@> r) = put x >> put r
get = (:~@>) <$> get <*> get
instance (UpdateLayer x) => (Serialize (RecurrentInputs '[FeedForward x])) where
put _ = return ()
get = return (ORS ())
instance (UpdateLayer x, Serialize (RecurrentInputs (y ': ys))) => (Serialize (RecurrentInputs (FeedForward x ': y ': ys))) where
put ( () :~~+> rest) = put rest
get = ( () :~~+> ) <$> get
instance (SingI (RecurrentShape x), RecurrentUpdateLayer x, Serialize (RecurrentInputs (y ': ys))) => (Serialize (RecurrentInputs (Recurrent x ': y ': ys))) where
put ( i :~@+> rest ) = do
_ <- (case i of
(S1D x) -> putListOf put . LA.toList . LAS.extract $ x
(S2D x) -> putListOf put . LA.toList . LA.flatten . LAS.extract $ x
(S3D x) -> putListOf put . LA.toList . LA.flatten . LAS.extract $ x
) :: PutM ()
put rest
get = do
Just i <- fromStorable . V.fromList <$> getListOf get
rest <- get
return ( i :~@+> rest)
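-- With these instances a trained recurrent network and its recurrent
-- inputs can be round-tripped through cereal; a sketch, following the
-- Shakespeare example in this commit:
--
-- > bytes = runPut (put (net, recIns))
-- > runGet get bytes :: Either String (Shakespeare, Shakespearian)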

View File

@ -19,6 +19,12 @@ import Grenade.Core.Shape
import Grenade.Recurrent.Core.Network
-- | Drive a network and collect its back propagated gradients.
--
-- TODO: split this nicely into backpropagate and update.
--
-- QUESTION: Should we return a list of gradients or the sum of
-- the gradients? It's different taking into account
-- momentum and L2.
trainRecurrent :: forall shapes layers. SingI (Last shapes)
=> LearningParameters
-> RecurrentNetwork layers shapes
@ -94,7 +100,7 @@ trainRecurrent rate network recinputs examples =
= () :~~+> updateRecInputs l xs ys
updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
= (x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
= (realToFrac (learningRate * learningRegulariser) * x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
updateRecInputs _ (ORS ()) (ORS ())
= ORS ()

View File

@ -8,6 +8,7 @@
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- GHC 7.10 doesn't see recurrent run functions as total.
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
@ -19,9 +20,13 @@ module Grenade.Recurrent.Layers.LSTM (
import Control.Monad.Random ( MonadRandom, getRandom )
-- import Data.List ( foldl1' )
import Data.Proxy
import Data.Serialize
import Data.Singletons.TypeLits
import qualified Numeric.LinearAlgebra as LA
import Numeric.LinearAlgebra.Static
import Grenade.Core.Network
@ -118,7 +123,6 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
-- v :: forall x ix. (x -> (R ix)) -> x -> x -> R ix
-- v e (e -> a) (e -> b) = a + b
createRandom = randomLSTM
instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
@ -245,3 +249,46 @@ sigmoid' x = logix * (1 - logix)
tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
put (LSTM LSTMWeights {..} _) = do
u lstmWf
u lstmUf
v lstmBf
u lstmWi
u lstmUi
v lstmBi
u lstmWo
u lstmUo
v lstmBo
u lstmWc
v lstmBc
where
u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
u = putListOf put . LA.toList . LA.flatten . extract
v :: forall a. (KnownNat a) => Putter (R a)
v = putListOf put . LA.toList . extract
get = do
lstmWf <- u
lstmUf <- u
lstmBf <- v
lstmWi <- u
lstmUi <- u
lstmBi <- v
lstmWo <- u
lstmUo <- u
lstmBo <- v
lstmWc <- u
lstmBc <- v
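-- Only the weights themselves are serialised; the second set of values
-- (assumed here to be the layer's momentum store) is rebuilt with zeros.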
return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
where
u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
u = let f = fromIntegral $ natVal (Proxy :: Proxy a)
in maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get
v :: forall a. (KnownNat a) => Get (R a)
v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get
w0 = konst 0
u0 = konst 0
v0 = konst 0

View File

@ -7,12 +7,18 @@ module Grenade.Recurrent.Layers.Trivial (
Trivial (..)
) where
import Data.Serialize
import Grenade.Core.Network
-- | A trivial layer.
data Trivial = Trivial
deriving Show
instance Serialize Trivial where
put _ = return ()
get = return Trivial
instance UpdateLayer Trivial where
type Gradient Trivial = ()
runUpdate _ _ _ = Trivial

View File

@ -11,8 +11,11 @@ module Grenade.Utils.OneHot (
, hotMap
, makeHot
, unHot
, sample
) where
import qualified Control.Monad.Random as MR
import Data.List ( group, sort )
import Data.Map ( Map )
@ -23,6 +26,7 @@ import Data.Singletons.TypeLits
import Data.Vector ( Vector )
import qualified Data.Vector as V
import qualified Data.Vector.Storable as VS
import Numeric.LinearAlgebra ( maxIndex )
import Numeric.LinearAlgebra.Devel
@ -76,9 +80,14 @@ makeHot m x = do
return vec
else Nothing
unHot :: forall a n. (KnownNat n)
=> Vector a -> (S ('D1 n)) -> Maybe a
unHot :: forall a n. KnownNat n
=> Vector a -> S ('D1 n) -> Maybe a
unHot v (S1D xs)
= (V.!?) v
$ maxIndex (extract xs)
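-- | Sample an element according to the network's output vector: each weight
-- is raised to the power (1 / temperature) before drawing, so low
-- temperatures sharpen the distribution and high temperatures flatten it.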
sample :: forall a n m. (KnownNat n, MR.MonadRandom m)
=> Double -> Vector a -> S ('D1 n) -> m a
sample temperature v (S1D xs) = do
ix <- MR.fromList . zip [0..] . fmap (toRational . exp . (/ temperature) . log) . VS.toList . extract $ xs
return $ v V.! ix

View File

@ -22,6 +22,7 @@ import Grenade.Layers.Convolution
import Disorder.Jack
import Test.Jack.Hmatrix
import Test.Jack.TypeLits
data OpaqueConvolution :: * where
OpaqueConvolution :: Convolution channels filters kernelRows kernelColumns strideRows strideColumns -> OpaqueConvolution
@ -42,12 +43,12 @@ genConvolution = Convolution <$> uniformSample <*> uniformSample
genOpaqueOpaqueConvolution :: Jack OpaqueConvolution
genOpaqueOpaqueConvolution = do
Just channels <- someNatVal <$> choose (1, 10)
Just filters <- someNatVal <$> choose (1, 10)
Just kernel_h <- someNatVal <$> choose (2, 20)
Just kernel_w <- someNatVal <$> choose (2, 20)
Just stride_h <- someNatVal <$> choose (1, 10)
Just stride_w <- someNatVal <$> choose (1, 10)
channels <- genNat
filters <- genNat
kernel_h <- genNat
kernel_w <- genNat
stride_h <- genNat
stride_w <- genNat
case (channels, filters, kernel_h, kernel_w, stride_h, stride_w) of
( SomeNat (pch :: Proxy ch), SomeNat (_ :: Proxy fl),
SomeNat (pkr :: Proxy kr), SomeNat (pkc :: Proxy kc),

View File

@ -90,7 +90,7 @@ prop_lstm_reference_backwards_cell =
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
refGradients = Reference.runLSTMbackOnCell refInput refNet refCell
in toList refGradients ~~~ (H.toList . S.extract $ actualGradients)
in toList refGradients ~~~ H.toList (S.extract actualGradients)
(~~~) as bs = all (< 1e-8) (zipWith (-) as bs)