mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Add a generative adversarial network
This commit is contained in:
parent
60e28aad8e
commit
6db2c40646
@ -135,19 +135,19 @@ mnist' = GanOpts <$> argument str (metavar "TRAIN")
|
||||
|
||||
main :: IO ()
|
||||
main = do
|
||||
GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm)
|
||||
putStrLn "Training stupidly simply GAN"
|
||||
nets0 <- case load of
|
||||
Just loadFile -> netLoad loadFile
|
||||
Nothing -> (,) <$> randomDiscriminator <*> randomGenerator
|
||||
GanOpts mnist iter rate load save <- execParser (info (mnist' <**> helper) idm)
|
||||
putStrLn "Training stupidly simply GAN"
|
||||
nets0 <- case load of
|
||||
Just loadFile -> netLoad loadFile
|
||||
Nothing -> (,) <$> randomDiscriminator <*> randomGenerator
|
||||
|
||||
res <- runExceptT $ ganTest nets0 iter mnist rate
|
||||
case res of
|
||||
Right nets1 -> case save of
|
||||
Just saveFile -> B.writeFile saveFile $ runPut (put nets1)
|
||||
Nothing -> return ()
|
||||
res <- runExceptT $ ganTest nets0 iter mnist rate
|
||||
case res of
|
||||
Right nets1 -> case save of
|
||||
Just saveFile -> B.writeFile saveFile $ runPut (put nets1)
|
||||
Nothing -> return ()
|
||||
|
||||
Left err -> putStrLn err
|
||||
Left err -> putStrLn err
|
||||
|
||||
readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
|
||||
readMNIST mnist = ExceptT $ do
|
||||
|
Loading…
Reference in New Issue
Block a user