mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Nicer types, and better regulariser so mnist actually works
This commit is contained in:
parent
70fbac3924
commit
109d5be3ca
@ -26,7 +26,7 @@ import Grenade
|
||||
-- between the shapes, so inference can't do it all for us.
|
||||
|
||||
-- With around 100000 examples, this should show two clear circles which have been learned by the network.
|
||||
randomNet :: (MonadRandom m) => m (Network '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1])
|
||||
randomNet :: MonadRandom m => m (Network '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1])
|
||||
randomNet = do
|
||||
a :: FullyConnected 2 40 <- randomFullyConnected
|
||||
b :: FullyConnected 40 10 <- randomFullyConnected
|
||||
@ -74,7 +74,7 @@ feedForward' =
|
||||
<*> (LearningParameters
|
||||
<$> option auto (long "train_rate" <> short 'r' <> value 0.01)
|
||||
<*> option auto (long "momentum" <> value 0.9)
|
||||
<*> option auto (long "l2" <> value 0.0001)
|
||||
<*> option auto (long "l2" <> value 0.0005)
|
||||
)
|
||||
|
||||
main :: IO ()
|
||||
|
@ -74,11 +74,11 @@ data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters
|
||||
mnist' :: Parser MnistOpts
|
||||
mnist' = MnistOpts <$> argument str (metavar "TRAIN")
|
||||
<*> argument str (metavar "VALIDATE")
|
||||
<*> option auto (long "iterations" <> short 'i' <> value 15)
|
||||
<*> option auto (long "iterations" <> short 'i' <> value 10)
|
||||
<*> (LearningParameters
|
||||
<$> option auto (long "train_rate" <> short 'r' <> value 0.01)
|
||||
<*> option auto (long "momentum" <> value 0.9)
|
||||
<*> option auto (long "l2" <> value 0.0001)
|
||||
<*> option auto (long "l2" <> value 0.0005)
|
||||
)
|
||||
|
||||
main :: IO ()
|
||||
|
@ -31,7 +31,7 @@ class UpdateLayer x where
|
||||
-- Unit if there isn't a gradient to pass back.
|
||||
type Gradient x :: *
|
||||
-- | Update a layer with its gradient and learning parameters
|
||||
runUpdate :: LearningParameters -> x -> Gradient x -> x
|
||||
runUpdate :: LearningParameters -> x -> Gradient x -> x
|
||||
|
||||
-- | Class for a layer. All layers implement this, however, they don't
|
||||
-- need to implement it for all shapes, only ones which are appropriate.
|
||||
|
Loading…
Reference in New Issue
Block a user