mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-29 03:31:58 +03:00
84 lines
3.1 KiB
Haskell
84 lines
3.1 KiB
Haskell
|
{-# LANGUAGE BangPatterns #-}
|
||
|
{-# LANGUAGE DataKinds #-}
|
||
|
{-# LANGUAGE KindSignatures #-}
|
||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||
|
{-# LANGUAGE TypeOperators #-}
|
||
|
{-# LANGUAGE TupleSections #-}
|
||
|
{-# LANGUAGE TypeFamilies #-}
|
||
|
{-# LANGUAGE FlexibleContexts #-}
|
||
|
|
||
|
import Control.Monad
|
||
|
import Control.Monad.Identity
|
||
|
import Control.Monad.Random
|
||
|
|
||
|
import GHC.TypeLits
|
||
|
|
||
|
import qualified Numeric.LinearAlgebra.Static as SA
|
||
|
|
||
|
import Options.Applicative
|
||
|
|
||
|
import Grenade
|
||
|
|
||
|
|
||
|
-- The definition of our simple feed-forward network.
|
||
|
-- The type level list represents the shapes passed through the layers. One can see that for this demonstration
|
||
|
-- we are using relu, tanh and logit non-linear units, which can easily be swapped in and out for one another.
|
||
|
|
||
|
-- It's important to keep the type signatures, as there are many layers which can "squeeze" into the gaps
|
||
|
-- between the shapes, so inference can't do it all for us.
|
||
|
|
||
|
-- With around 100000 examples, this should show two clear circles which have been learned by the network.
|
||
|
-- | A freshly initialised 2 -> 40 -> 10 -> 1 network.
--
-- Three randomly initialised fully connected layers are drawn in order
-- (2x40, then 40x10, then 10x1). They are separated by tanh and relu
-- activations and capped with a logistic output unit.
randomNet :: (MonadRandom m) => m (Network Identity '[('D1 2), ('D1 40), ('D1 40), ('D1 10), ('D1 10), ('D1 1), ('D1 1)])
randomNet =
    -- The pattern signatures pin each layer's dimensions. The shapes of the
    -- activation layers in between come from the network type above.
    (\(hidden1 :: FullyConnected 2 40) (hidden2 :: FullyConnected 40 10) (output :: FullyConnected 10 1) ->
        hidden1 :~> Tanh :~> hidden2 :~> Relu :~> output :~> O Logit)
      <$> randomFullyConnected
      <*> randomFullyConnected
      <*> randomFullyConnected
|
||
|
|
||
|
-- | Train a fresh 'randomNet' on @n@ random points from the square
-- [-1, 1] x [-1, 1] using learning rate @rate@, then render what it learned.
--
-- A training point is labelled 1 when it lies within distance 0.33 of either
-- (0.33, 0.33) or (-0.33, -0.33), and 0 otherwise.
--
-- The result is a 21-line ASCII picture, 51 characters per line, of the
-- trained network's output sampled over the same square. x grows from left
-- to right and y grows from bottom to top. Larger outputs are drawn with
-- denser characters.
netTest :: MonadRandom m => Double -> Int -> m String
netTest rate n = do
    inps <- replicateM n $ do
      s <- getRandom
      -- Affine rescale v * 2 - 1. Assuming SA.Uniform samples [0, 1),
      -- the points land in [-1, 1).
      return $ S1D' $ SA.randomVector s SA.Uniform * 2 - 1
    let outs = flip map inps $ \(S1D' v) ->
          if v `inCircle` (fromRational 0.33, 0.33)
              || v `inCircle` (fromRational (-0.33), 0.33)
            then S1D' $ fromRational 1
            else S1D' $ fromRational 0
    net0 <- randomNet

    return . runIdentity $ do
      trained <- foldM trainEach net0 (zip inps outs)
      -- The first row is printed at the top of the terminal, so rows are
      -- generated from y = 20 (maps to +1) down to y = 0 (maps to -1).
      -- Generating them in ascending order would draw the plot upside down.
      let testIns = [ [ (x, y) | x <- [0 .. 50] ]
                    | y <- [20, 19 .. 0] ]

      outMat <- traverse (traverse (\(x, y) -> (render . normx) <$> runNet trained (S1D' $ SA.vector [x / 25 - 1, y / 10 - 1]))) testIns
      return $ unlines outMat

  where
    -- Is @v@ within Euclidean distance @r@ of the centre @o@?
    inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
    v `inCircle` (o, r) = SA.norm_2 (v - o) <= r

    -- One fold step. The bang patterns force the accumulated network to WHNF
    -- at every example, so the fold does not pile up unevaluated updates.
    trainEach !nt !(i, o) = train rate i o nt

    -- Map a network output to a character, denser for larger values.
    render n'
      | n' <= 0.2 = ' '
      | n' <= 0.4 = '.'
      | n' <= 0.6 = '-'
      | n' <= 0.8 = '='
      | otherwise = '#'

    -- Collapse the single-element output vector to a plain Double.
    normx :: S' ('D1 1) -> Double
    normx (S1D' r) = SA.mean r
|
||
|
|
||
|
|
||
|
-- | Settings read from the command line by the @feedForward'@ parser.
data FeedForwardOpts
  = FeedForwardOpts
      Int    -- number of random training examples
      Double -- learning rate used for every training step
|
||
|
|
||
|
-- | Command-line parser for 'FeedForwardOpts'.
--
-- * @--examples@ / @-e@: number of random training examples
--   (default 1000000).
-- * @--train_rate@ / @-r@: learning rate for each training step
--   (default 0.01).
--
-- Help text, metavariables and defaults are attached so that @--help@
-- documents every flag.
feedForward' :: Parser FeedForwardOpts
feedForward' =
  FeedForwardOpts
    <$> option auto
          (  long "examples"
          <> short 'e'
          <> value 1000000
          <> showDefault
          <> metavar "INT"
          <> help "Number of random training examples"
          )
    <*> option auto
          (  long "train_rate"
          <> short 'r'
          <> value 0.01
          <> showDefault
          <> metavar "DOUBLE"
          <> help "Learning rate used for each training step"
          )
|
||
|
|
||
|
-- | Entry point. It parses the options, announces training, then trains the
-- network with the global random generator and prints the rendered picture.
main :: IO ()
main = do
    opts <- execParser parserInfo
    case opts of
      FeedForwardOpts examples rate -> do
        putStrLn "Training network..."
        picture <- evalRandIO (netTest rate examples)
        putStrLn picture
  where
    -- The option parser, extended with the standard --help flag.
    parserInfo = info (feedForward' <**> helper) idm
|