Add very basic model saving and loading

Huw Campbell 2016-12-10 09:31:01 +11:00
parent 20e7e483d7
commit 4dc408f39d
13 changed files with 131 additions and 22 deletions

View File

@@ -89,7 +89,6 @@ and the tests run using:
Grenade is currently known to build with ghc 7.10 and 8.0.
Thanks
------
Writing a library like this has been on my mind for a while now, but a big shout
@@ -108,8 +107,13 @@ Being purely functional, it should also be easy to run batches in parallel, which
would be appropriate for larger networks, my current examples however are single
threaded.
Training 15 generations over Kaggle's 42000 sample MNIST training set took under
an hour on my laptop, achieving 0.5% error rate on a 1000 sample holdout set.
Contributing
------------

View File

@@ -22,6 +22,7 @@ library
, containers
, deepseq
, either == 4.4.*
, cereal
, exceptions == 0.8.*
, hmatrix
, MonadRandom
@@ -85,6 +86,8 @@ executable feedforward
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.12.*
, text == 1.2.*

View File

@@ -8,6 +8,9 @@ import Control.Monad
import Control.Monad.Random
import Data.List ( foldl' )
import qualified Data.ByteString as B
import Data.Serialize
import GHC.TypeLits
import qualified Numeric.LinearAlgebra.Static as SA
@@ -16,13 +19,12 @@ import Options.Applicative
import Grenade
-- The definition for our simple feed forward network.
-- The type level list represents the shapes passed through the layers. One can see that for this demonstration
-- we are using relu, tanh and logit non-linear units, which can be easily substituted for each other in and out.
-- It's important to keep the type signatures, as there are many layers which can "squeeze" into the gaps
-- between the shapes, so inference can't do it all for us.
-- The type level lists represent the layers and the shapes passed through the layers.
-- One can see that for this demonstration we are using relu, tanh and logit non-linear
-- units, which can be easily substituted for each other in and out.
--
-- With around 100000 examples, this should show two clear circles which have been learned by the network.
type FFNet = Network '[ FullyConnected 2 40, Tanh, FullyConnected 40 10, Relu, FullyConnected 10 1, Logit ]
'[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1]
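Because both lists are type-level, substituting one non-linearity for another is a one-line, compiler-checked change. A minimal sketch (FFNet2 is a hypothetical name, not part of this commit) swapping the positions of Relu and Tanh while the shape list stays identical, since the non-linear units preserve shape:

    type FFNet2 = Network '[ FullyConnected 2 40, Relu, FullyConnected 40 10, Tanh, FullyConnected 10 1, Logit ]
                          '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1]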
@@ -30,8 +32,8 @@ type FFNet = Network '[ FullyConnected 2 40, Tanh, FullyConnected 40 10, Relu, F
randomNet :: MonadRandom m => m FFNet
randomNet = randomNetwork
netTest :: MonadRandom m => LearningParameters -> Int -> m String
netTest rate n = do
netTrain :: FFNet -> LearningParameters -> Int -> IO FFNet
netTrain net0 rate n = do
inps <- replicateM n $ do
s <- getRandom
return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
@@ -39,20 +41,28 @@ netTest rate n = do
if v `inCircle` (fromRational 0.33, 0.33) || v `inCircle` (fromRational (-0.33), 0.33)
then S1D $ fromRational 1
else S1D $ fromRational 0
net0 <- randomNet
let trained = foldl' trainEach net0 (zip inps outs)
let testIns = [ [ (x,y) | x <- [0..50] ]
| y <- [0..20] ]
let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
return $ unlines outMat
return trained
where
inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
v `inCircle` (o, r) = SA.norm_2 (v - o) <= r
trainEach !network (i,o) = train rate network i o
netLoad :: FilePath -> IO FFNet
netLoad modelPath = do
modelData <- B.readFile modelPath
either fail return $ runGet (get :: Get FFNet) modelData
netScore :: FFNet -> IO ()
netScore network = do
let testIns = [ [ (x,y) | x <- [0..50] ]
| y <- [0..20] ]
outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet network (S1D $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
putStrLn $ unlines outMat
where
render n' | n' <= 0.2 = ' '
| n' <= 0.4 = '.'
| n' <= 0.6 = '-'
@@ -62,7 +72,7 @@ netTest rate n = do
normx :: S ('D1 1) -> Double
normx (S1D r) = SA.mean r
data FeedForwardOpts = FeedForwardOpts Int LearningParameters
data FeedForwardOpts = FeedForwardOpts Int LearningParameters (Maybe FilePath) (Maybe FilePath)
feedForward' :: Parser FeedForwardOpts
feedForward' =
@@ -72,9 +82,19 @@ feedForward' =
<*> option auto (long "momentum" <> value 0.9)
<*> option auto (long "l2" <> value 0.0005)
)
<*> optional (strOption (long "load"))
<*> optional (strOption (long "save"))
main :: IO ()
main = do
FeedForwardOpts examples rate <- execParser (info (feedForward' <**> helper) idm)
putStrLn "Training network..."
putStrLn =<< evalRandIO (netTest rate examples)
FeedForwardOpts examples rate load save <- execParser (info (feedForward' <**> helper) idm)
net0 <- case load of
Just loadFile -> netLoad loadFile
Nothing -> randomNet
net <- netTrain net0 rate examples
netScore net
case save of
Just saveFile -> B.writeFile saveFile $ runPut (put net)
Nothing -> return ()
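Saving mirrors netLoad: the whole network is written with cereal's runPut and read back with runGet over strict ByteStrings. A hedged sketch of the pair as standalone helpers (saveNet and loadNet are hypothetical names; netLoad above is the committed version, and B is Data.ByteString as imported in this file):

    -- Write the serialised network to disk.
    saveNet :: FilePath -> FFNet -> IO ()
    saveNet path net = B.writeFile path $ runPut (put net)

    -- Read it back, failing in IO on a malformed or mismatched file.
    loadNet :: FilePath -> IO FFNet
    loadNet path = either fail return . runGet (get :: Get FFNet) =<< B.readFile path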

View File

@@ -50,11 +50,11 @@ type F = FeedForward
type R = Recurrent
-- The definition of our network
type Shakespeare = RecurrentNetwork '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit]
'[ 'D1 40, 'D1 40, 'D1 40, 'D1 40 ]
type Shakespeare = RecurrentNetwork '[ R (LSTM 40 50), R (LSTM 50 40), F (FullyConnected 40 40), F Logit]
'[ 'D1 40, 'D1 50, 'D1 40, 'D1 40, 'D1 40 ]
-- The definition of the "sideways" input, which the network if fed recurrently.
type Shakespearian = RecurrentInputs '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit]
type Shakespearian = RecurrentInputs '[ R (LSTM 40 50), R (LSTM 50 40), F (FullyConnected 40 40), F Logit]
randomNet :: MonadRandom m => m (Shakespeare, Shakespearian)
randomNet = randomRecurrent

View File

@ -29,6 +29,8 @@ import Control.Monad.Random (MonadRandom)
import Data.List ( foldl' )
import Data.Singletons
import Data.Serialize
import Grenade.Core.Shape
-- | Learning parameters for stochastic gradient descent.
@@ -101,3 +103,14 @@ instance (SingI i, SingI o, Layer x i o) => CreatableNetwork (x ': '[]) (i ': o
instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
randomNetwork = (:~>) <$> createRandom <*> randomNetwork
-- | Add very simple serialisation to the network
instance (SingI i, SingI o, Layer x i o, Serialize x) => Serialize (Network '[x] '[i, o]) where
put (O x) = put x
put _ = error "impossible"
get = O <$> get
instance (SingI i, SingI o, Layer x i o, Serialize x, Serialize (Network xs (o ': r ': rs))) => Serialize (Network (x ': xs) (i ': o ': r ': rs)) where
put (x :~> r) = put x >> put r
get = (:~>) <$> get <*> get
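The two instances recurse over the type-level layer list: the base case writes the final layer, and the inductive case writes the head and then the tail, so a network's encoding is simply its layers' encodings concatenated in order. A minimal round-trip sketch using cereal's top-level encode/decode (TinyNet is a hypothetical example type, not part of this commit):

    type TinyNet = Network '[ FullyConnected 2 1, Logit ] '[ 'D1 2, 'D1 1, 'D1 1 ]

    -- decode . encode should yield back an equivalent network.
    reload :: TinyNet -> Either String TinyNet
    reload = decode . encode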

View File

@@ -16,7 +16,9 @@ module Grenade.Layers.Convolution (
import Control.Monad.Random hiding ( fromList )
import Data.Maybe
import Data.Proxy
import Data.Serialize
import Data.Singletons.TypeLits
import GHC.TypeLits
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
@@ -125,6 +127,21 @@ instance ( KnownNat channels
createRandom = randomConvolution
instance ( KnownNat channels
, KnownNat filters
, KnownNat kernelRows
, KnownNat kernelColumns
, KnownNat strideRows
, KnownNat strideColumns
, KnownNat (kernelRows * kernelColumns * channels)
) => Serialize (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
put (Convolution w _) = putListOf put . toList . flatten . extract $ w
get = do
let f = fromIntegral $ natVal (Proxy :: Proxy filters)
wN <- maybe (fail "Vector of incorrect size") return . create . reshape f . LA.fromList =<< getListOf get
let mm = konst 0
return $ Convolution wN mm
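Note the design choice: put writes only the weight matrix, and get rebuilds the momentum matrix as konst 0, so optimiser state does not survive a round trip. A hedged illustration (decodeConv is a hypothetical helper with arbitrary concrete sizes, assuming Data.ByteString imported qualified as B):

    -- Weights are restored exactly, but momentum restarts at zero, so the
    -- first updates after loading behave as though training had just begun.
    decodeConv :: B.ByteString -> Either String (Convolution 1 10 5 5 1 1)
    decodeConv = runGet get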
-- | A two dimensional image may have a convolution filter applied to it
instance ( KnownNat kernelRows
, KnownNat kernelCols

View File

@@ -7,6 +7,8 @@ module Grenade.Layers.Flatten (
FlattenLayer (..)
) where
import Data.Serialize
import Data.Singletons.TypeLits
import GHC.TypeLits
@@ -41,6 +43,11 @@ instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer FlattenLayer ('D3
runForwards _ (S3D y) = S2D y
runBackwards _ _ (S2D y) = ((), S3D y)
instance Serialize FlattenLayer where
put _ = return ()
get = return FlattenLayer
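FlattenLayer holds no data, so its put writes nothing and its get simply returns the constructor; the same zero-byte pattern is reused below for Logit, Pad, Pooling, Relu and Tanh. A one-line sanity sketch (statelessRoundTrip is a hypothetical helper):

    statelessRoundTrip :: Either String FlattenLayer
    statelessRoundTrip = runGet get (runPut (put FlattenLayer))  -- Right FlattenLayer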
fromJust' :: Maybe x -> x
fromJust' (Just x) = x
fromJust' Nothing = error $ "FlattenLayer error: data shape couldn't be converted."

View File

@@ -3,6 +3,7 @@
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Grenade.Layers.FullyConnected (
FullyConnected (..)
, randomFullyConnected
@@ -10,8 +11,11 @@ module Grenade.Layers.FullyConnected (
import Control.Monad.Random hiding (fromList)
import Data.Proxy
import Data.Serialize
import Data.Singletons.TypeLits
import qualified Numeric.LinearAlgebra as LA
import Numeric.LinearAlgebra.Static
import Grenade.Core.Network
@@ -55,6 +59,19 @@ instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o)
dWs = tr wN #> dEdy
in (FullyConnected' wB' mm', S1D dWs)
instance (KnownNat i, KnownNat o) => Serialize (FullyConnected i o) where
put (FullyConnected b _ w _) = do
putListOf put . LA.toList . extract $ b
putListOf put . LA.toList . LA.flatten . extract $ w
get = do
let f = fromIntegral $ natVal (Proxy :: Proxy i)
b <- maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get
k <- maybe (fail "Vector of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get
let bm = konst 0
let mm = konst 0
return $ FullyConnected b bm k mm
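As with Convolution, only the bias and weight values are written; both momentum fields come back as konst 0 on load. A minimal round-trip check, assuming cereal's exact IEEE Double encoding (prop_roundTrip is a hypothetical helper, not committed code; extract gives concrete hmatrix values with Eq):

    prop_roundTrip :: FullyConnected 3 2 -> Bool
    prop_roundTrip fc@(FullyConnected b _ w _) =
      case runGet (get :: Get (FullyConnected 3 2)) (runPut (put fc)) of
        Right (FullyConnected b' _ w' _) ->
          extract b == extract b' && extract w == extract w'
        Left _ -> False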
randomFullyConnected :: (MonadRandom m, KnownNat i, KnownNat o)
=> m (FullyConnected i o)
randomFullyConnected = do

View File

@@ -7,6 +7,8 @@ module Grenade.Layers.Logit (
) where
import Data.Serialize
import Data.Singletons.TypeLits
import Grenade.Core.Network
import Grenade.Core.Shape
@@ -35,6 +37,10 @@ instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i
runForwards _ (S3D y) = S3D (logistic y)
runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (logistic' y * dEdy))
instance Serialize Logit where
put _ = return ()
get = return Logit
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x))

View File

@@ -10,6 +10,7 @@ module Grenade.Layers.Pad (
import Data.Maybe
import Data.Proxy
import Data.Serialize
import Data.Singletons.TypeLits
import GHC.TypeLits
@@ -34,6 +35,10 @@ instance UpdateLayer (Pad l t r b) where
runUpdate _ x _ = x
createRandom = return Pad
instance Serialize (Pad l t r b) where
put _ = return ()
get = return Pad
-- | A two dimensional image can be padded.
instance ( KnownNat padLeft
, KnownNat padTop

View File

@@ -12,6 +12,7 @@ module Grenade.Layers.Pooling (
import Data.Maybe
import Data.Proxy
import Data.Serialize
import Data.Singletons.TypeLits
import GHC.TypeLits
@@ -39,6 +40,10 @@ instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns)
runUpdate _ Pooling _ = Pooling
createRandom = return Pooling
instance Serialize (Pooling kernelRows kernelColumns strideRows strideColumns) where
put _ = return ()
get = return Pooling
-- | A two dimensional image can be pooled.
instance ( KnownNat kernelRows
, KnownNat kernelColumns

View File

@@ -6,6 +6,8 @@ module Grenade.Layers.Relu (
Relu (..)
) where
import Data.Serialize
import GHC.TypeLits
import Grenade.Core.Network
import Grenade.Core.Shape
@@ -23,6 +25,10 @@ instance UpdateLayer Relu where
runUpdate _ _ _ = Relu
createRandom = return Relu
instance Serialize Relu where
put _ = return ()
get = return Relu
instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where
runForwards _ (S1D y) = S1D (relu y)
where

View File

@@ -6,6 +6,8 @@ module Grenade.Layers.Tanh (
Tanh (..)
) where
import Data.Serialize
import GHC.TypeLits
import Grenade.Core.Network
import Grenade.Core.Shape
@@ -20,6 +22,10 @@ instance UpdateLayer Tanh where
runUpdate _ _ _ = Tanh
createRandom = return Tanh
instance Serialize Tanh where
put _ = return ()
get = return Tanh
instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
runForwards _ (S1D y) = S1D (tanh y)
runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (tanh' y * dEdy))