From 60e28aad8e9e02563281586e9a1f152889ff7beb Mon Sep 17 00:00:00 2001 From: Huw Campbell Date: Wed, 8 Feb 2017 18:03:38 +1100 Subject: [PATCH 1/2] Add more reshape instances --- grenade.cabal | 19 ++++ main/gan-mnist.hs | 167 ++++++++++++++++++++++++++++++++++ src/Grenade/Layers/Reshape.hs | 15 ++- 3 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 main/gan-mnist.hs diff --git a/grenade.cabal b/grenade.cabal index c5d7293..a414941 100644 --- a/grenade.cabal +++ b/grenade.cabal @@ -130,6 +130,25 @@ executable mnist , MonadRandom , vector +executable gan-mnist + ghc-options: -Wall -threaded -O2 + main-is: main/gan-mnist.hs + build-depends: base + , grenade + , attoparsec + , bytestring + , cereal + , either + , optparse-applicative == 0.13.* + , text == 1.2.* + , mtl >= 2.2.1 && < 2.3 + , hmatrix >= 0.18 && < 0.19 + , transformers + , semigroups + , singletons + , MonadRandom + , vector + executable recurrent ghc-options: -Wall -threaded -O2 main-is: main/recurrent.hs diff --git a/main/gan-mnist.hs b/main/gan-mnist.hs new file mode 100644 index 0000000..0099b6f --- /dev/null +++ b/main/gan-mnist.hs @@ -0,0 +1,167 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE FlexibleContexts #-} + +-- This is a simple generative adversarial network to make pictures +-- of numbers similar to those in MNIST. +-- +-- It demonstrates a different usage of the library, within a few hours +-- was producing examples like this: +-- +-- --. +-- .=-.--..#=### +-- -##==#########. +-- #############- +-- -###-.=..-.-== +-- ###- +-- .###- +-- .####...==-. +-- -####=--.=##= +-- -##=- -## +-- =## +-- -##= +-- -###- +-- .####. +-- .#####. +-- ...---=#####- +-- .=#########. . +-- .#######=. . +-- . =-. +-- +-- It's a 5! +-- +import Control.Applicative +import Control.Monad +import Control.Monad.Random +import Control.Monad.Trans.Except + +import qualified Data.Attoparsec.Text as A +import qualified Data.ByteString as B +import Data.List ( foldl' ) +import Data.Semigroup ( (<>) ) +import Data.Serialize +import qualified Data.Text as T +import qualified Data.Text.IO as T +import qualified Data.Vector.Storable as V + +import qualified Numeric.LinearAlgebra.Static as SA +import Numeric.LinearAlgebra.Data ( toLists ) + +import Options.Applicative + +import Grenade +import Grenade.Utils.OneHot + +type Discriminator = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 1, Logit] + '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1] + +type Generator = Network '[ FullyConnected 100 10240, Relu, Reshape, Convolution 10 10 5 5 1 1, Relu, Convolution 10 1 1 1 1 1, Logit, Reshape] + '[ 'D1 100, 'D1 10240, 'D1 10240, 'D3 32 32 10, 'D3 28 28 10, 'D3 28 28 10, 'D3 28 28 1, 'D3 28 28 1, 'D2 28 28 ] + +randomDiscriminator :: MonadRandom m => m Discriminator +randomDiscriminator = randomNetwork + +randomGenerator :: MonadRandom m => m Generator +randomGenerator = randomNetwork + +trainExample :: LearningParameters -> Discriminator -> Generator -> S ('D2 28 28) -> S ('D1 100) -> ( Discriminator, Generator ) +trainExample rate discriminator generator realExample noiseSource + = let (generatorTape, fakeExample) = runNetwork generator noiseSource + + (discriminatorTapeReal, guessReal) = runNetwork discriminator realExample + (discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample + + (discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 ) + (discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake + + (generator', _) = runGradient generator generatorTape (-push) + + newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10}) discriminator [ discriminator'real, discriminator'fake ] + newGenerator = applyUpdate rate generator generator' + in ( newDiscriminator, newGenerator ) + + +ganTest :: (Discriminator, Generator) -> Int -> FilePath -> LearningParameters -> ExceptT String IO (Discriminator, Generator) +ganTest (discriminator0, generator0) iterations trainFile rate = do + trainData <- fmap fst <$> readMNIST trainFile + + lift $ foldM (runIteration trainData) ( discriminator0, generator0 ) [1..iterations] + + where + + showShape' :: S ('D2 a b) -> IO () + showShape' (S2D mm) = putStrLn $ + let m = SA.extract mm + ms = toLists m + render n' | n' <= 0.2 = ' ' + | n' <= 0.4 = '.' + | n' <= 0.6 = '-' + | n' <= 0.8 = '=' + | otherwise = '#' + + px = (fmap . fmap) render ms + in unlines px + + runIteration :: [S ('D2 28 28)] -> (Discriminator, Generator) -> Int -> IO (Discriminator, Generator) + runIteration trainData ( !discriminator, !generator ) _ = do + trained' <- foldM ( \(!discriminatorX, !generatorX ) realExample -> do + fakeExample <- randomOfShape + return $ trainExample rate discriminatorX generatorX realExample fakeExample + ) ( discriminator, generator ) trainData + + + showShape' . snd . runNetwork (snd trained') =<< randomOfShape + + return trained' + +data GanOpts = GanOpts FilePath Int LearningParameters (Maybe FilePath) (Maybe FilePath) + +mnist' :: Parser GanOpts +mnist' = GanOpts <$> argument str (metavar "TRAIN") + <*> option auto (long "iterations" <> short 'i' <> value 15) + <*> (LearningParameters + <$> option auto (long "train_rate" <> short 'r' <> value 0.01) + <*> option auto (long "momentum" <> value 0.9) + <*> option auto (long "l2" <> value 0.0005) + ) + <*> optional (strOption (long "load")) + <*> optional (strOption (long "save")) + + +main :: IO () +main = do + GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm) + putStrLn "Training stupidly simply GAN" + nets0 <- case load of + Just loadFile -> netLoad loadFile + Nothing -> (,) <$> randomDiscriminator <*> randomGenerator + + res <- runExceptT $ ganTest nets0 iter mnist rate + case res of + Right nets1 -> case save of + Just saveFile -> B.writeFile saveFile $ runPut (put nets1) + Nothing -> return () + + Left err -> putStrLn err + +readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))] +readMNIST mnist = ExceptT $ do + mnistdata <- T.readFile mnist + return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata) + +parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10)) +parseMNIST = do + Just lab <- oneHot <$> A.decimal + pixels <- many (A.char ',' >> A.double) + image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels) + return (image, lab) + +netLoad :: FilePath -> IO (Discriminator, Generator) +netLoad modelPath = do + modelData <- B.readFile modelPath + either fail return $ runGet (get :: Get (Discriminator, Generator)) modelData diff --git a/src/Grenade/Layers/Reshape.hs b/src/Grenade/Layers/Reshape.hs index 2f7e5ca..0793341 100644 --- a/src/Grenade/Layers/Reshape.hs +++ b/src/Grenade/Layers/Reshape.hs @@ -21,7 +21,10 @@ import Grenade.Core -- -- Flattens input down to D1 from either 2D or 3D data. -- --- Can also be used to turn a 3D image with only one channel into a 2D image. +-- Casts input D1 up to either 2D or 3D data if the shapes are good. +-- +-- Can also be used to turn a 3D image with only one channel into a 2D image +-- or vice versa. data Reshape = Reshape deriving Show @@ -50,6 +53,16 @@ instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer Reshape ('D2 x y) runForwards _ (S2D y) = ((), S3D y) runBackwards _ _ (S3D y) = ((), S2D y) +instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer Reshape ('D1 a) ('D2 x y) where + type Tape Reshape ('D1 a) ('D2 x y) = () + runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y) + runBackwards _ _ (S2D y) = ((), fromJust' . fromStorable . flatten . extract $ y) + +instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer Reshape ('D1 a) ('D3 x y z) where + type Tape Reshape ('D1 a) ('D3 x y z) = () + runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y) + runBackwards _ _ (S3D y) = ((), fromJust' . fromStorable . flatten . extract $ y) + instance Serialize Reshape where put _ = return () get = return Reshape From 6db2c4064648db4d4b7bb04440d288a3d16ce93c Mon Sep 17 00:00:00 2001 From: Huw Campbell Date: Wed, 8 Feb 2017 18:14:13 +1100 Subject: [PATCH 2/2] Add a generative adversarial network --- main/gan-mnist.hs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/main/gan-mnist.hs b/main/gan-mnist.hs index 0099b6f..a0dd392 100644 --- a/main/gan-mnist.hs +++ b/main/gan-mnist.hs @@ -135,19 +135,19 @@ mnist' = GanOpts <$> argument str (metavar "TRAIN") main :: IO () main = do - GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm) - putStrLn "Training stupidly simply GAN" - nets0 <- case load of - Just loadFile -> netLoad loadFile - Nothing -> (,) <$> randomDiscriminator <*> randomGenerator + GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm) + putStrLn "Training stupidly simply GAN" + nets0 <- case load of + Just loadFile -> netLoad loadFile + Nothing -> (,) <$> randomDiscriminator <*> randomGenerator - res <- runExceptT $ ganTest nets0 iter mnist rate - case res of - Right nets1 -> case save of - Just saveFile -> B.writeFile saveFile $ runPut (put nets1) - Nothing -> return () + res <- runExceptT $ ganTest nets0 iter mnist rate + case res of + Right nets1 -> case save of + Just saveFile -> B.writeFile saveFile $ runPut (put nets1) + Nothing -> return () - Left err -> putStrLn err + Left err -> putStrLn err readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))] readMNIST mnist = ExceptT $ do