grenade/main/feedforward.hs

81 lines
3.1 KiB
Haskell
Raw Normal View History

2016-06-23 15:12:57 +03:00
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
import Control.Monad
import Control.Monad.Random
import Data.List ( foldl' )
2016-06-23 15:12:57 +03:00
import GHC.TypeLits
import qualified Numeric.LinearAlgebra.Static as SA
import Options.Applicative
import Grenade
-- The definition for our simple feed forward network.
-- The type level list represents the shapes passed through the layers. One can see that for this demonstration
-- we are using relu, tanh and logit non-linear units, which can easily be swapped in and out for each other.
-- It's important to keep the type signatures, as there are many layers which can "squeeze" into the gaps
-- between the shapes, so inference can't do it all for us.
-- With around 100000 examples, this should show two clear circles which have been learned by the network.
2016-12-13 02:06:40 +03:00
type FFNet
  = Network
      -- Layers, listed in the order they appear in the network.
      '[ FullyConnected 2 40
       , Tanh
       , FullyConnected 40 10
       , Relu
       , FullyConnected 10 1
       , Logit
       ]
      -- Shapes passed through the layers (one more entry than there are layers).
      '[ 'D1 2
       , 'D1 40
       , 'D1 40
       , 'D1 10
       , 'D1 10
       , 'D1 1
       , 'D1 1
       ]
-- | An 'FFNet' obtained from 'randomNetwork' in any 'MonadRandom'.
-- The signature is what pins the network type that 'randomNetwork' has to build.
randomNet :: MonadRandom m => m FFNet
randomNet = randomNetwork
2016-06-23 15:12:57 +03:00
-- | Train a fresh 'FFNet' on @n@ random points and render what it learned as text.
--
-- Every training input is a 2-vector made by 'SA.randomVector' with 'SA.Uniform' and
-- rescaled with @* 2 - 1@. Its target is 1 when the point lies within 0.33 of either
-- circle centre and 0 otherwise. After folding 'train' over all examples with the given
-- learning parameters, the network is evaluated on a 51 x 21 grid whose coordinates run
-- from -1 to 1, and each output becomes one character of the returned picture.
netTest :: MonadRandom m => LearningParameters -> Int -> m String
netTest rate n = do
    -- The inputs are drawn before the network, so the order of random draws is fixed.
    samples <- replicateM n (toPoint <$> getRandom)
    initial <- randomNet

    let targets  = map classify samples
        learned  = foldl' step initial (zip samples targets)
        -- Grid indices are mapped onto [-1, 1] on both axes before querying the network.
        cell x y = render . normx $ runNet learned (S1D $ SA.vector [x / 25 - 1, y / 10 - 1])
        rows     = [ [ cell x y | x <- [0..50] ] | y <- [0..20] ]

    return $ unlines rows
  where
    -- Build one input vector from a random seed.
    toPoint seed = S1D $ SA.randomVector seed SA.Uniform * 2 - 1

    -- Target output for an input: 1 inside either circle, 0 outside both.
    classify :: S ('D1 2) -> S ('D1 1)
    classify (S1D v)
      | any (\centre -> inCircle v (centre, 0.33)) [fromRational 0.33, fromRational (-0.33)]
                  = S1D $ fromRational 1
      | otherwise = S1D $ fromRational 0

    -- Is the point no further than the radius from the centre?
    inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
    inCircle v (o, r) = SA.norm_2 (v - o) <= r

    -- One training step; the bang keeps the fold from piling up unevaluated networks.
    step !net (input, target) = train rate net input target

    -- Map a network output onto one of five characters of increasing density.
    render :: Double -> Char
    render level
      | level <= 0.2 = ' '
      | level <= 0.4 = '.'
      | level <= 0.6 = '-'
      | level <= 0.8 = '='
      | otherwise    = '#'

    -- Collapse the one-element output vector to a plain number.
    normx :: S ('D1 1) -> Double
    normx (S1D r) = SA.mean r
2016-06-23 15:12:57 +03:00
-- | Command line options for the example, as produced by the option parser.
data FeedForwardOpts
  = FeedForwardOpts
      Int                 -- number of training examples to generate
      LearningParameters  -- parameters handed to training
2016-06-23 15:12:57 +03:00
-- | Command line parser for 'FeedForwardOpts'.
--
-- Every option has a default, so the program runs without arguments:
-- @--examples@ / @-e@ (100000), @--train_rate@ / @-r@ (0.01), @--momentum@ (0.9)
-- and @--l2@ (0.0005).
feedForward' :: Parser FeedForwardOpts
feedForward' = FeedForwardOpts <$> exampleCount <*> learningParams
  where
    exampleCount   = option auto (long "examples" <> short 'e' <> value 100000)

    -- The three learning parameters are parsed in constructor argument order.
    learningParams = LearningParameters <$> trainRate <*> momentum <*> l2Penalty

    trainRate      = option auto (long "train_rate" <> short 'r' <> value 0.01)
    momentum       = option auto (long "momentum" <> value 0.9)
    l2Penalty      = option auto (long "l2" <> value 0.0005)
2016-06-23 15:12:57 +03:00
-- | Parse the command line, train the network and print the rendered result.
main :: IO ()
main = do
    opts <- execParser cli
    case opts of
      FeedForwardOpts exampleCount params -> do
        putStrLn "Training network..."
        picture <- evalRandIO (netTest params exampleCount)
        putStrLn picture
  where
    -- The option parser with the standard help flag attached and no extra info modifiers.
    cli = info (feedForward' <**> helper) idm