mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Remove pad from MNIST example, tidy ups
This commit is contained in:
parent
ca4b0fe912
commit
dfbb6c17b8
@ -13,10 +13,10 @@ Five is right out.
|
||||
Grenade is a dependently typed, practical, and pretty quick neural network library for concise and precise
|
||||
specifications of complex networks in Haskell.
|
||||
|
||||
As an example, a network which can achieve less than 1.5% error on mnist can be specified and
|
||||
As an example, a network which can achieve less than 1.5% error on MNIST can be specified and
|
||||
initialised with random weights in under 10 lines of code with
|
||||
```haskell
|
||||
randomMnistNet :: MonadRandom m => m (Network Identity '[('D2 28 28), ('D3 24 24 10), ('D3 12 12 10), ('D3 12 12 10), ('D3 8 8 16), ('D3 4 4 16), ('D1 256), ('D1 256), ('D1 80), ('D1 80), ('D1 10), ('D1 10)])
|
||||
randomMnistNet :: MonadRandom m => m (Network '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
|
||||
randomMnistNet = do
|
||||
a :: Convolution 1 10 5 5 1 1 <- randomConvolution
|
||||
let b :: Pooling 2 2 2 2 = Pooling
|
||||
|
@ -26,7 +26,7 @@ import Grenade
|
||||
-- between the shapes, so inference can't do it all for us.
|
||||
|
||||
-- With around 100000 examples, this should show two clear circles which have been learned by the network.
|
||||
randomNet :: (MonadRandom m) => m (Network '[('D1 2), ('D1 40), ('D1 40), ('D1 10), ('D1 10), ('D1 1), ('D1 1)])
|
||||
randomNet :: (MonadRandom m) => m (Network '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1])
|
||||
randomNet = do
|
||||
a :: FullyConnected 2 40 <- randomFullyConnected
|
||||
b :: FullyConnected 40 10 <- randomFullyConnected
|
||||
|
@ -30,16 +30,15 @@ import Grenade
|
||||
|
||||
-- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
|
||||
-- this network should get down to about a 1.3% error rate.
|
||||
randomMnistNet :: (MonadRandom m) => m (Network '[('D2 28 28), ('D2 32 32), ('D3 28 28 10), ('D3 14 14 10), ('D3 14 14 10), ('D3 10 10 16), ('D3 5 5 16), ('D1 400), ('D1 400), ('D1 80), ('D1 80), ('D1 10), ('D1 10)])
|
||||
randomMnistNet :: MonadRandom m => m (Network '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
|
||||
randomMnistNet = do
|
||||
let pad :: Pad 2 2 2 2 = Pad
|
||||
a :: Convolution 1 10 5 5 1 1 <- randomConvolution
|
||||
let b :: Pooling 2 2 2 2 = Pooling
|
||||
c :: Convolution 10 16 5 5 1 1 <- randomConvolution
|
||||
let d :: Pooling 2 2 2 2 = Pooling
|
||||
e :: FullyConnected 400 80 <- randomFullyConnected
|
||||
e :: FullyConnected 256 80 <- randomFullyConnected
|
||||
f :: FullyConnected 80 10 <- randomFullyConnected
|
||||
return $ pad :~> a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit
|
||||
return $ a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit
|
||||
|
||||
convTest :: Int -> FilePath -> FilePath -> LearningParameters -> IO ()
|
||||
convTest iterations trainFile validateFile rate = do
|
||||
@ -73,8 +72,8 @@ convTest iterations trainFile validateFile rate = do
|
||||
data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters
|
||||
|
||||
mnist' :: Parser MnistOpts
|
||||
mnist' = MnistOpts <$> (argument str (metavar "TRAIN"))
|
||||
<*> (argument str (metavar "VALIDATE"))
|
||||
mnist' = MnistOpts <$> argument str (metavar "TRAIN")
|
||||
<*> argument str (metavar "VALIDATE")
|
||||
<*> option auto (long "iterations" <> short 'i' <> value 15)
|
||||
<*> (LearningParameters
|
||||
<$> option auto (long "train_rate" <> short 'r' <> value 0.01)
|
||||
|
@ -39,17 +39,17 @@ data Shape =
|
||||
instance KnownShape x => Num (S' x) where
|
||||
(+) (S1D' x) (S1D' y) = S1D' (x + y)
|
||||
(+) (S2D' x) (S2D' y) = S2D' (x + y)
|
||||
(+) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' + y') x y)
|
||||
(+) (S3D' x) (S3D' y) = S3D' (vectorZip (+) x y)
|
||||
(+) _ _ = error "Impossible to have different constructors for the same shaped network"
|
||||
|
||||
(-) (S1D' x) (S1D' y) = S1D' (x - y)
|
||||
(-) (S2D' x) (S2D' y) = S2D' (x - y)
|
||||
(-) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' - y') x y)
|
||||
(-) (S3D' x) (S3D' y) = S3D' (vectorZip (-) x y)
|
||||
(-) _ _ = error "Impossible to have different constructors for the same shaped network"
|
||||
|
||||
(*) (S1D' x) (S1D' y) = S1D' (x * y)
|
||||
(*) (S2D' x) (S2D' y) = S2D' (x * y)
|
||||
(*) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' * y') x y)
|
||||
(*) (S3D' x) (S3D' y) = S3D' (vectorZip (*) x y)
|
||||
(*) _ _ = error "Impossible to have different constructors for the same shaped network"
|
||||
|
||||
abs (S1D' x) = S1D' (abs x)
|
||||
|
@ -26,7 +26,7 @@ instance Foldable (Vector n) where
|
||||
foldr f b (Vector as) = foldr f b as
|
||||
|
||||
instance KnownNat n => Traversable (Vector n) where
|
||||
traverse f (Vector as) = fmap mkVector $ traverse f as
|
||||
traverse f (Vector as) = mkVector <$> traverse f as
|
||||
|
||||
instance Functor (Vector n) where
|
||||
fmap f (Vector as) = Vector (fmap f as)
|
||||
@ -41,7 +41,7 @@ mkVector :: forall n a. KnownNat n => [a] -> Vector n a
|
||||
mkVector as
|
||||
= let du = fromIntegral . natVal $ (undefined :: Proxy n)
|
||||
la = length as
|
||||
in if (du == la)
|
||||
in if du == la
|
||||
then Vector as
|
||||
else error $ "Error creating staticly sized Vector of length: " ++
|
||||
show du ++ " list is of length:" ++ show la
|
||||
|
Loading…
Reference in New Issue
Block a user