Remove pad from MNIST example, tidy ups

This commit is contained in:
Huw Campbell 2016-12-03 00:04:01 +11:00
parent ca4b0fe912
commit dfbb6c17b8
5 changed files with 13 additions and 14 deletions

View File

@ -13,10 +13,10 @@ Five is right out.
Grenade is a dependently typed, practical, and pretty quick neural network library for concise and precise
specifications of complex networks in Haskell.
As an example, a network which can achieve less than 1.5% error on mnist can be specified and
As an example, a network which can achieve less than 1.5% error on MNIST can be specified and
initialised with random weights in under 10 lines of code with
```haskell
randomMnistNet :: MonadRandom m => m (Network Identity '[('D2 28 28), ('D3 24 24 10), ('D3 12 12 10), ('D3 12 12 10), ('D3 8 8 16), ('D3 4 4 16), ('D1 256), ('D1 256), ('D1 80), ('D1 80), ('D1 10), ('D1 10)])
randomMnistNet :: MonadRandom m => m (Network '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
randomMnistNet = do
a :: Convolution 1 10 5 5 1 1 <- randomConvolution
let b :: Pooling 2 2 2 2 = Pooling

View File

@ -26,7 +26,7 @@ import Grenade
-- between the shapes, so inference can't do it all for us.
-- With around 100000 examples, this should show two clear circles which have been learned by the network.
randomNet :: (MonadRandom m) => m (Network '[('D1 2), ('D1 40), ('D1 40), ('D1 10), ('D1 10), ('D1 1), ('D1 1)])
randomNet :: (MonadRandom m) => m (Network '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1])
randomNet = do
a :: FullyConnected 2 40 <- randomFullyConnected
b :: FullyConnected 40 10 <- randomFullyConnected

View File

@ -30,16 +30,15 @@ import Grenade
-- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
-- this network should get down to about a 1.3% error rate.
randomMnistNet :: (MonadRandom m) => m (Network '[('D2 28 28), ('D2 32 32), ('D3 28 28 10), ('D3 14 14 10), ('D3 14 14 10), ('D3 10 10 16), ('D3 5 5 16), ('D1 400), ('D1 400), ('D1 80), ('D1 80), ('D1 10), ('D1 10)])
randomMnistNet :: MonadRandom m => m (Network '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
randomMnistNet = do
let pad :: Pad 2 2 2 2 = Pad
a :: Convolution 1 10 5 5 1 1 <- randomConvolution
let b :: Pooling 2 2 2 2 = Pooling
c :: Convolution 10 16 5 5 1 1 <- randomConvolution
let d :: Pooling 2 2 2 2 = Pooling
e :: FullyConnected 400 80 <- randomFullyConnected
e :: FullyConnected 256 80 <- randomFullyConnected
f :: FullyConnected 80 10 <- randomFullyConnected
return $ pad :~> a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit
return $ a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit
convTest :: Int -> FilePath -> FilePath -> LearningParameters -> IO ()
convTest iterations trainFile validateFile rate = do
@ -73,8 +72,8 @@ convTest iterations trainFile validateFile rate = do
data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters
mnist' :: Parser MnistOpts
mnist' = MnistOpts <$> (argument str (metavar "TRAIN"))
<*> (argument str (metavar "VALIDATE"))
mnist' = MnistOpts <$> argument str (metavar "TRAIN")
<*> argument str (metavar "VALIDATE")
<*> option auto (long "iterations" <> short 'i' <> value 15)
<*> (LearningParameters
<$> option auto (long "train_rate" <> short 'r' <> value 0.01)

View File

@ -39,17 +39,17 @@ data Shape =
instance KnownShape x => Num (S' x) where
(+) (S1D' x) (S1D' y) = S1D' (x + y)
(+) (S2D' x) (S2D' y) = S2D' (x + y)
(+) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' + y') x y)
(+) (S3D' x) (S3D' y) = S3D' (vectorZip (+) x y)
(+) _ _ = error "Impossible to have different constructors for the same shaped network"
(-) (S1D' x) (S1D' y) = S1D' (x - y)
(-) (S2D' x) (S2D' y) = S2D' (x - y)
(-) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' - y') x y)
(-) (S3D' x) (S3D' y) = S3D' (vectorZip (-) x y)
(-) _ _ = error "Impossible to have different constructors for the same shaped network"
(*) (S1D' x) (S1D' y) = S1D' (x * y)
(*) (S2D' x) (S2D' y) = S2D' (x * y)
(*) (S3D' x) (S3D' y) = S3D' (vectorZip (\x' y' -> x' * y') x y)
(*) (S3D' x) (S3D' y) = S3D' (vectorZip (*) x y)
(*) _ _ = error "Impossible to have different constructors for the same shaped network"
abs (S1D' x) = S1D' (abs x)

View File

@ -26,7 +26,7 @@ instance Foldable (Vector n) where
foldr f b (Vector as) = foldr f b as
instance KnownNat n => Traversable (Vector n) where
traverse f (Vector as) = fmap mkVector $ traverse f as
traverse f (Vector as) = mkVector <$> traverse f as
instance Functor (Vector n) where
fmap f (Vector as) = Vector (fmap f as)
@ -41,7 +41,7 @@ mkVector :: forall n a. KnownNat n => [a] -> Vector n a
mkVector as
= let du = fromIntegral . natVal $ (undefined :: Proxy n)
la = length as
in if (du == la)
in if du == la
then Vector as
else error $ "Error creating staticly sized Vector of length: " ++
show du ++ " list is of length:" ++ show la