Add a generative adversarial network

This commit is contained in:
Huw Campbell 2017-02-08 18:14:13 +11:00
parent 60e28aad8e
commit 6db2c40646

View File

@ -135,19 +135,19 @@ mnist' = GanOpts <$> argument str (metavar "TRAIN")
main :: IO ()
main = do
GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm)
putStrLn "Training stupidly simply GAN"
nets0 <- case load of
Just loadFile -> netLoad loadFile
Nothing -> (,) <$> randomDiscriminator <*> randomGenerator
GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm)
putStrLn "Training stupidly simply GAN"
nets0 <- case load of
Just loadFile -> netLoad loadFile
Nothing -> (,) <$> randomDiscriminator <*> randomGenerator
res <- runExceptT $ ganTest nets0 iter mnist rate
case res of
Right nets1 -> case save of
Just saveFile -> B.writeFile saveFile $ runPut (put nets1)
Nothing -> return ()
res <- runExceptT $ ganTest nets0 iter mnist rate
case res of
Right nets1 -> case save of
Just saveFile -> B.writeFile saveFile $ runPut (put nets1)
Nothing -> return ()
Left err -> putStrLn err
Left err -> putStrLn err
readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
readMNIST mnist = ExceptT $ do