grenade/main/gan-mnist.hs
2017-02-08 18:14:13 +11:00

168 lines
6.4 KiB
Haskell

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
-- This is a simple generative adversarial network to make pictures
-- of numbers similar to those in MNIST.
--
-- It demonstrates a different usage of the library; within a few hours it
-- was producing examples like this:
--
-- --.
-- .=-.--..#=###
-- -##==#########.
-- #############-
-- -###-.=..-.-==
-- ###-
-- .###-
-- .####...==-.
-- -####=--.=##=
-- -##=- -##
-- =##
-- -##=
-- -###-
-- .####.
-- .#####.
-- ...---=#####-
-- .=#########. .
-- .#######=. .
-- . =-.
--
-- It's a 5!
--
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Control.Monad.Trans.Except
import qualified Data.Attoparsec.Text as A
import qualified Data.ByteString as B
import Data.List ( foldl' )
import Data.Semigroup ( (<>) )
import Data.Serialize
import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified Data.Vector.Storable as V
import qualified Numeric.LinearAlgebra.Static as SA
import Numeric.LinearAlgebra.Data ( toLists )
import Options.Applicative
import System.Exit ( die )

import Grenade
import Grenade.Utils.OneHot
-- | The discriminator network: takes a 28x28 two dimensional input through
-- two convolution + pooling stages, flattens the result to 256 values and
-- reduces it through two fully connected layers to a single output value.
--
-- The first list names the layers, the second the shape of the data between
-- them (one more shape than there are layers).
type Discriminator =
  Network
    '[ Convolution 1 10 5 5 1 1
     , Pooling 2 2 2 2
     , Relu
     , Convolution 10 16 5 5 1 1
     , Pooling 2 2 2 2
     , Reshape
     , Relu
     , FullyConnected 256 80
     , Logit
     , FullyConnected 80 1
     , Logit
     ]
    '[ 'D2 28 28
     , 'D3 24 24 10
     , 'D3 12 12 10
     , 'D3 12 12 10
     , 'D3 8 8 16
     , 'D3 4 4 16
     , 'D1 256
     , 'D1 256
     , 'D1 80
     , 'D1 80
     , 'D1 1
     , 'D1 1
     ]
-- | The generator network: expands a 100 element input vector to 10240
-- values, reshapes them to a 32x32x10 volume, and convolves that down to a
-- single 28x28 two dimensional output.
--
-- The first list names the layers, the second the shape of the data between
-- them (one more shape than there are layers).
type Generator =
  Network
    '[ FullyConnected 100 10240
     , Relu
     , Reshape
     , Convolution 10 10 5 5 1 1
     , Relu
     , Convolution 10 1 1 1 1 1
     , Logit
     , Reshape
     ]
    '[ 'D1 100
     , 'D1 10240
     , 'D1 10240
     , 'D3 32 32 10
     , 'D3 28 28 10
     , 'D3 28 28 10
     , 'D3 28 28 1
     , 'D3 28 28 1
     , 'D2 28 28
     ]
-- | Produce a 'Discriminator' with 'randomNetwork', in any 'MonadRandom'
-- monad. Used by 'main' as the starting point when no saved model is loaded.
randomDiscriminator :: MonadRandom m => m Discriminator
randomDiscriminator = randomNetwork
-- | Produce a 'Generator' with 'randomNetwork', in any 'MonadRandom' monad.
-- Used by 'main' as the starting point when no saved model is loaded.
randomGenerator :: MonadRandom m => m Generator
randomGenerator = randomNetwork
-- | Run one training step of both networks on a single real example.
--
-- The generator turns @noiseSource@ into a fake image; the discriminator is
-- then run on both the real and the fake image. The discriminator is updated
-- twice (once per image), the generator once.
--
-- Returns the updated @(discriminator, generator)@ pair.
trainExample
  :: LearningParameters
  -> Discriminator
  -> Generator
  -> S ('D2 28 28)
  -> S ('D1 100)
  -> ( Discriminator, Generator )
trainExample rate discriminator generator realExample noiseSource =
  ( discriminatorNext, generatorNext )
 where
  -- Forward passes.
  (generatorTape, generated) = runNetwork generator noiseSource
  (realTape, realScore)      = runNetwork discriminator realExample
  (fakeTape, fakeScore)      = runNetwork discriminator generated

  -- Backward passes. The real image is scored against a target of 1
  -- (score - 1), the fake image against a target of 0 (score unchanged).
  (realGradients, _)             = runGradient discriminator realTape ( realScore - 1 )
  (fakeGradients, fakeInputGrad) = runGradient discriminator fakeTape fakeScore

  -- The generator receives the negated gradient that the discriminator
  -- pushed back onto its (fake) input.
  (generatorGradients, _) = runGradient generator generatorTape (negate fakeInputGrad)

  -- The discriminator is updated with ten times the configured regulariser.
  discriminatorRate = rate { learningRegulariser = learningRegulariser rate * 10 }

  discriminatorNext =
    foldl' (applyUpdate discriminatorRate) discriminator [ realGradients, fakeGradients ]
  generatorNext = applyUpdate rate generator generatorGradients
-- | Train a discriminator / generator pair on the images in @trainFile@.
--
-- The file is read with 'readMNIST' (labels are discarded). The whole image
-- list is then passed over @iterations@ times; after every pass one image is
-- generated from fresh random input and printed to stdout as ASCII art.
--
-- Fails in 'ExceptT' with the parser's message if the file cannot be parsed.
ganTest
  :: (Discriminator, Generator)
  -> Int
  -> FilePath
  -> LearningParameters
  -> ExceptT String IO (Discriminator, Generator)
ganTest (discriminator0, generator0) iterations trainFile rate = do
  images <- map fst <$> readMNIST trainFile
  lift $ foldM (\nets _ -> runPass images nets) ( discriminator0, generator0 ) [1 .. iterations]
 where
  -- Print a two dimensional shape, one character per cell.
  printShape :: S ('D2 a b) -> IO ()
  printShape (S2D mm) =
    putStrLn . unlines . map (map shade) . toLists $ SA.extract mm

  -- Bucket a cell value into one of five characters of increasing density.
  shade :: Double -> Char
  shade v
    | v <= 0.2  = ' '
    | v <= 0.4  = '.'
    | v <= 0.6  = '-'
    | v <= 0.8  = '='
    | otherwise = '#'

  -- One pass over every training image, followed by printing a sample.
  runPass :: [S ('D2 28 28)] -> (Discriminator, Generator) -> IO (Discriminator, Generator)
  runPass images ( !discriminator, !generator ) = do
    trained <- foldM step ( discriminator, generator ) images
    sampleInput <- randomOfShape
    printShape . snd $ runNetwork (snd trained) sampleInput
    return trained

  -- Train on a single real image against freshly drawn generator input.
  -- The bangs keep both networks evaluated as the fold proceeds.
  step :: (Discriminator, Generator) -> S ('D2 28 28) -> IO (Discriminator, Generator)
  step ( !discriminatorX, !generatorX ) realExample = do
    noise <- randomOfShape
    return $ trainExample rate discriminatorX generatorX realExample noise
-- | Command line options for the program, as parsed by @mnist'@ and
-- consumed by @main@.
data GanOpts
  = GanOpts
      FilePath            -- path of the training data file
      Int                 -- number of passes over the training data
      LearningParameters  -- learning parameters handed to the training loop
      (Maybe FilePath)    -- optional file to load an initial pair of networks from
      (Maybe FilePath)    -- optional file to write the trained networks to
-- | Command line parser for 'GanOpts'.
--
-- Takes one positional argument (@TRAIN@) plus the options
-- @--iterations@ / @-i@ (default 15), @--train_rate@ / @-r@ (default 0.01),
-- @--momentum@ (default 0.9), @--l2@ (default 0.0005), @--load@ and
-- @--save@.
mnist' :: Parser GanOpts
mnist' =
  GanOpts
    <$> trainFile
    <*> iterationCount
    <*> learningParameters
    <*> loadFile
    <*> saveFile
 where
  trainFile :: Parser FilePath
  trainFile = argument str (metavar "TRAIN")

  iterationCount :: Parser Int
  iterationCount = option auto (long "iterations" <> short 'i' <> value 15)

  -- The three options fill the 'LearningParameters' fields positionally.
  learningParameters :: Parser LearningParameters
  learningParameters =
    LearningParameters
      <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
      <*> option auto (long "momentum" <> value 0.9)
      <*> option auto (long "l2" <> value 0.0005)

  loadFile :: Parser (Maybe FilePath)
  loadFile = optional (strOption (long "load"))

  saveFile :: Parser (Maybe FilePath)
  saveFile = optional (strOption (long "save"))
-- | Program entry point.
--
-- Parses the command line, obtains a starting pair of networks (loaded from
-- the @--load@ file when given, otherwise freshly created), trains them with
-- 'ganTest', and writes the result to the @--save@ file when one was given.
--
-- If training fails, the error message is written to stderr and the process
-- exits with a failure status, so that callers and scripts can detect it.
main :: IO ()
main = do
  GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm)
  putStrLn "Training stupidly simply GAN"

  nets0 <- case load of
    Just loadFile -> netLoad loadFile
    Nothing       -> (,) <$> randomDiscriminator <*> randomGenerator

  res <- runExceptT $ ganTest nets0 iter mnist rate
  case res of
    -- Only serialise the networks when a save path was requested.
    Right nets1 -> forM_ save $ \saveFile ->
      B.writeFile saveFile $ runPut (put nets1)
    -- 'die' prints to stderr and exits non-zero, unlike a plain 'putStrLn'.
    Left err -> die err
-- | Read a text file of training examples, one per line.
--
-- Every line is parsed with 'parseMNIST'; the first line that fails to parse
-- aborts the whole read with that parser's error message. Exceptions raised
-- while reading the file itself are not caught here.
readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
readMNIST mnist = ExceptT $ parseRows <$> T.readFile mnist
 where
  parseRows = traverse (A.parseOnly parseMNIST) . T.lines
-- | Parse one training example: a decimal label followed by a sequence of
-- comma-prefixed floating point pixel values.
--
-- The label is converted with 'oneHot' and the pixels with 'fromStorable';
-- when either conversion yields 'Nothing' the parser fails with a
-- descriptive message instead of an anonymous pattern match failure.
parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
parseMNIST = do
  digit  <- A.decimal
  -- Explicit failure rather than a refutable @Just lab <-@ bind, so the
  -- error says what actually went wrong.
  lab    <- maybe (fail "Parsed label could not be one-hot encoded") pure (oneHot digit)
  pixels <- many (A.char ',' *> A.double)
  image  <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
  return (image, lab)
-- | Load a serialised @(Discriminator, Generator)@ pair from a file.
--
-- The file's bytes are decoded with 'runGet'; a decoding error is raised
-- in 'IO' via 'fail' with the decoder's message.
netLoad :: FilePath -> IO (Discriminator, Generator)
netLoad modelPath = do
  bytes <- B.readFile modelPath
  -- The result type of the signature fixes which 'get' instance is used.
  case runGet get bytes of
    Left err   -> fail err
    Right nets -> return nets