mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-28 11:16:07 +03:00
Merge pull request #25 from HuwCampbell/topic/generative-adversarial
Topic/generative adversarial
This commit is contained in:
commit
c3f0373fbe
@ -130,6 +130,25 @@ executable mnist
|
|||||||
, MonadRandom
|
, MonadRandom
|
||||||
, vector
|
, vector
|
||||||
|
|
||||||
|
executable gan-mnist
|
||||||
|
ghc-options: -Wall -threaded -O2
|
||||||
|
main-is: main/gan-mnist.hs
|
||||||
|
build-depends: base
|
||||||
|
, grenade
|
||||||
|
, attoparsec
|
||||||
|
, bytestring
|
||||||
|
, cereal
|
||||||
|
, either
|
||||||
|
, optparse-applicative == 0.13.*
|
||||||
|
, text == 1.2.*
|
||||||
|
, mtl >= 2.2.1 && < 2.3
|
||||||
|
, hmatrix >= 0.18 && < 0.19
|
||||||
|
, transformers
|
||||||
|
, semigroups
|
||||||
|
, singletons
|
||||||
|
, MonadRandom
|
||||||
|
, vector
|
||||||
|
|
||||||
executable recurrent
|
executable recurrent
|
||||||
ghc-options: -Wall -threaded -O2
|
ghc-options: -Wall -threaded -O2
|
||||||
main-is: main/recurrent.hs
|
main-is: main/recurrent.hs
|
||||||
|
167
main/gan-mnist.hs
Normal file
167
main/gan-mnist.hs
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
{-# LANGUAGE BangPatterns #-}
|
||||||
|
{-# LANGUAGE DataKinds #-}
|
||||||
|
{-# LANGUAGE GADTs #-}
|
||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
{-# LANGUAGE TypeOperators #-}
|
||||||
|
{-# LANGUAGE TupleSections #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
|
|
||||||
|
-- This is a simple generative adversarial network to make pictures
|
||||||
|
-- of numbers similar to those in MNIST.
|
||||||
|
--
|
||||||
|
-- It demonstrates a different usage of the library, within a few hours
|
||||||
|
-- was producing examples like this:
|
||||||
|
--
|
||||||
|
-- --.
|
||||||
|
-- .=-.--..#=###
|
||||||
|
-- -##==#########.
|
||||||
|
-- #############-
|
||||||
|
-- -###-.=..-.-==
|
||||||
|
-- ###-
|
||||||
|
-- .###-
|
||||||
|
-- .####...==-.
|
||||||
|
-- -####=--.=##=
|
||||||
|
-- -##=- -##
|
||||||
|
-- =##
|
||||||
|
-- -##=
|
||||||
|
-- -###-
|
||||||
|
-- .####.
|
||||||
|
-- .#####.
|
||||||
|
-- ...---=#####-
|
||||||
|
-- .=#########. .
|
||||||
|
-- .#######=. .
|
||||||
|
-- . =-.
|
||||||
|
--
|
||||||
|
-- It's a 5!
|
||||||
|
--
|
||||||
|
import Control.Applicative
|
||||||
|
import Control.Monad
|
||||||
|
import Control.Monad.Random
|
||||||
|
import Control.Monad.Trans.Except
|
||||||
|
|
||||||
|
import qualified Data.Attoparsec.Text as A
|
||||||
|
import qualified Data.ByteString as B
|
||||||
|
import Data.List ( foldl' )
|
||||||
|
import Data.Semigroup ( (<>) )
|
||||||
|
import Data.Serialize
|
||||||
|
import qualified Data.Text as T
|
||||||
|
import qualified Data.Text.IO as T
|
||||||
|
import qualified Data.Vector.Storable as V
|
||||||
|
|
||||||
|
import qualified Numeric.LinearAlgebra.Static as SA
|
||||||
|
import Numeric.LinearAlgebra.Data ( toLists )
|
||||||
|
|
||||||
|
import Options.Applicative
|
||||||
|
|
||||||
|
import Grenade
|
||||||
|
import Grenade.Utils.OneHot
|
||||||
|
|
||||||
|
type Discriminator = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 1, Logit]
|
||||||
|
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1]
|
||||||
|
|
||||||
|
type Generator = Network '[ FullyConnected 100 10240, Relu, Reshape, Convolution 10 10 5 5 1 1, Relu, Convolution 10 1 1 1 1 1, Logit, Reshape]
|
||||||
|
'[ 'D1 100, 'D1 10240, 'D1 10240, 'D3 32 32 10, 'D3 28 28 10, 'D3 28 28 10, 'D3 28 28 1, 'D3 28 28 1, 'D2 28 28 ]
|
||||||
|
|
||||||
|
randomDiscriminator :: MonadRandom m => m Discriminator
|
||||||
|
randomDiscriminator = randomNetwork
|
||||||
|
|
||||||
|
randomGenerator :: MonadRandom m => m Generator
|
||||||
|
randomGenerator = randomNetwork
|
||||||
|
|
||||||
|
trainExample :: LearningParameters -> Discriminator -> Generator -> S ('D2 28 28) -> S ('D1 100) -> ( Discriminator, Generator )
|
||||||
|
trainExample rate discriminator generator realExample noiseSource
|
||||||
|
= let (generatorTape, fakeExample) = runNetwork generator noiseSource
|
||||||
|
|
||||||
|
(discriminatorTapeReal, guessReal) = runNetwork discriminator realExample
|
||||||
|
(discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample
|
||||||
|
|
||||||
|
(discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 )
|
||||||
|
(discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake
|
||||||
|
|
||||||
|
(generator', _) = runGradient generator generatorTape (-push)
|
||||||
|
|
||||||
|
newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10}) discriminator [ discriminator'real, discriminator'fake ]
|
||||||
|
newGenerator = applyUpdate rate generator generator'
|
||||||
|
in ( newDiscriminator, newGenerator )
|
||||||
|
|
||||||
|
|
||||||
|
ganTest :: (Discriminator, Generator) -> Int -> FilePath -> LearningParameters -> ExceptT String IO (Discriminator, Generator)
|
||||||
|
ganTest (discriminator0, generator0) iterations trainFile rate = do
|
||||||
|
trainData <- fmap fst <$> readMNIST trainFile
|
||||||
|
|
||||||
|
lift $ foldM (runIteration trainData) ( discriminator0, generator0 ) [1..iterations]
|
||||||
|
|
||||||
|
where
|
||||||
|
|
||||||
|
showShape' :: S ('D2 a b) -> IO ()
|
||||||
|
showShape' (S2D mm) = putStrLn $
|
||||||
|
let m = SA.extract mm
|
||||||
|
ms = toLists m
|
||||||
|
render n' | n' <= 0.2 = ' '
|
||||||
|
| n' <= 0.4 = '.'
|
||||||
|
| n' <= 0.6 = '-'
|
||||||
|
| n' <= 0.8 = '='
|
||||||
|
| otherwise = '#'
|
||||||
|
|
||||||
|
px = (fmap . fmap) render ms
|
||||||
|
in unlines px
|
||||||
|
|
||||||
|
runIteration :: [S ('D2 28 28)] -> (Discriminator, Generator) -> Int -> IO (Discriminator, Generator)
|
||||||
|
runIteration trainData ( !discriminator, !generator ) _ = do
|
||||||
|
trained' <- foldM ( \(!discriminatorX, !generatorX ) realExample -> do
|
||||||
|
fakeExample <- randomOfShape
|
||||||
|
return $ trainExample rate discriminatorX generatorX realExample fakeExample
|
||||||
|
) ( discriminator, generator ) trainData
|
||||||
|
|
||||||
|
|
||||||
|
showShape' . snd . runNetwork (snd trained') =<< randomOfShape
|
||||||
|
|
||||||
|
return trained'
|
||||||
|
|
||||||
|
data GanOpts = GanOpts FilePath Int LearningParameters (Maybe FilePath) (Maybe FilePath)
|
||||||
|
|
||||||
|
mnist' :: Parser GanOpts
|
||||||
|
mnist' = GanOpts <$> argument str (metavar "TRAIN")
|
||||||
|
<*> option auto (long "iterations" <> short 'i' <> value 15)
|
||||||
|
<*> (LearningParameters
|
||||||
|
<$> option auto (long "train_rate" <> short 'r' <> value 0.01)
|
||||||
|
<*> option auto (long "momentum" <> value 0.9)
|
||||||
|
<*> option auto (long "l2" <> value 0.0005)
|
||||||
|
)
|
||||||
|
<*> optional (strOption (long "load"))
|
||||||
|
<*> optional (strOption (long "save"))
|
||||||
|
|
||||||
|
|
||||||
|
main :: IO ()
|
||||||
|
main = do
|
||||||
|
GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm)
|
||||||
|
putStrLn "Training stupidly simply GAN"
|
||||||
|
nets0 <- case load of
|
||||||
|
Just loadFile -> netLoad loadFile
|
||||||
|
Nothing -> (,) <$> randomDiscriminator <*> randomGenerator
|
||||||
|
|
||||||
|
res <- runExceptT $ ganTest nets0 iter mnist rate
|
||||||
|
case res of
|
||||||
|
Right nets1 -> case save of
|
||||||
|
Just saveFile -> B.writeFile saveFile $ runPut (put nets1)
|
||||||
|
Nothing -> return ()
|
||||||
|
|
||||||
|
Left err -> putStrLn err
|
||||||
|
|
||||||
|
readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
|
||||||
|
readMNIST mnist = ExceptT $ do
|
||||||
|
mnistdata <- T.readFile mnist
|
||||||
|
return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)
|
||||||
|
|
||||||
|
parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
|
||||||
|
parseMNIST = do
|
||||||
|
Just lab <- oneHot <$> A.decimal
|
||||||
|
pixels <- many (A.char ',' >> A.double)
|
||||||
|
image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
|
||||||
|
return (image, lab)
|
||||||
|
|
||||||
|
netLoad :: FilePath -> IO (Discriminator, Generator)
|
||||||
|
netLoad modelPath = do
|
||||||
|
modelData <- B.readFile modelPath
|
||||||
|
either fail return $ runGet (get :: Get (Discriminator, Generator)) modelData
|
@ -21,7 +21,10 @@ import Grenade.Core
|
|||||||
--
|
--
|
||||||
-- Flattens input down to D1 from either 2D or 3D data.
|
-- Flattens input down to D1 from either 2D or 3D data.
|
||||||
--
|
--
|
||||||
-- Can also be used to turn a 3D image with only one channel into a 2D image.
|
-- Casts input D1 up to either 2D or 3D data if the shapes are good.
|
||||||
|
--
|
||||||
|
-- Can also be used to turn a 3D image with only one channel into a 2D image
|
||||||
|
-- or vice versa.
|
||||||
data Reshape = Reshape
|
data Reshape = Reshape
|
||||||
deriving Show
|
deriving Show
|
||||||
|
|
||||||
@ -50,6 +53,16 @@ instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer Reshape ('D2 x y)
|
|||||||
runForwards _ (S2D y) = ((), S3D y)
|
runForwards _ (S2D y) = ((), S3D y)
|
||||||
runBackwards _ _ (S3D y) = ((), S2D y)
|
runBackwards _ _ (S3D y) = ((), S2D y)
|
||||||
|
|
||||||
|
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer Reshape ('D1 a) ('D2 x y) where
|
||||||
|
type Tape Reshape ('D1 a) ('D2 x y) = ()
|
||||||
|
runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
|
||||||
|
runBackwards _ _ (S2D y) = ((), fromJust' . fromStorable . flatten . extract $ y)
|
||||||
|
|
||||||
|
instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer Reshape ('D1 a) ('D3 x y z) where
|
||||||
|
type Tape Reshape ('D1 a) ('D3 x y z) = ()
|
||||||
|
runForwards _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
|
||||||
|
runBackwards _ _ (S3D y) = ((), fromJust' . fromStorable . flatten . extract $ y)
|
||||||
|
|
||||||
instance Serialize Reshape where
|
instance Serialize Reshape where
|
||||||
put _ = return ()
|
put _ = return ()
|
||||||
get = return Reshape
|
get = return Reshape
|
||||||
|
Loading…
Reference in New Issue
Block a user