2016-06-23 15:12:57 +03:00
|
|
|
{-# LANGUAGE BangPatterns #-}
|
|
|
|
{-# LANGUAGE DataKinds #-}
|
|
|
|
{-# LANGUAGE ScopedTypeVariables #-}
|
|
|
|
{-# LANGUAGE TypeOperators #-}
|
|
|
|
{-# LANGUAGE TupleSections #-}
|
|
|
|
{-# LANGUAGE TypeFamilies #-}
|
|
|
|
import Control.Monad
|
|
|
|
import Control.Monad.Random
|
2016-12-20 08:31:09 +03:00
|
|
|
import Data.List ( foldl' )
|
|
|
|
|
2016-12-10 01:31:01 +03:00
|
|
|
import qualified Data.ByteString as B
|
|
|
|
import Data.Serialize
|
2017-02-03 13:04:27 +03:00
|
|
|
import Data.Semigroup ( (<>) )
|
2016-12-10 01:31:01 +03:00
|
|
|
|
2016-06-23 15:12:57 +03:00
|
|
|
import GHC.TypeLits
|
|
|
|
|
|
|
|
import qualified Numeric.LinearAlgebra.Static as SA
|
|
|
|
|
|
|
|
import Options.Applicative
|
|
|
|
|
|
|
|
import Grenade
|
|
|
|
|
|
|
|
|
2016-12-10 01:31:01 +03:00
|
|
|
-- The definition for our simple feed-forward network.
|
|
|
|
-- The type-level lists represent the layers and the shapes passed through the layers.
|
|
|
|
-- One can see that for this demonstration we are using relu, tanh and logit non-linear
|
|
|
|
-- units, which can easily be substituted for one another.
|
|
|
|
--
|
2016-06-23 15:12:57 +03:00
|
|
|
-- With around 100000 examples, this should show two clear circles which have been learned by the network.
|
2016-12-13 02:06:40 +03:00
|
|
|
-- | The network trained by this demo: fully connected layers taking a
-- 2-vector down to a single output, with tanh, relu and logit activations.
--
-- The second type-level list is the shape flowing between consecutive
-- layers: each 'FullyConnected' changes the width, each activation keeps it.
type FFNet =
  Network
    '[ FullyConnected 2 40, Tanh
     , FullyConnected 40 10, Relu
     , FullyConnected 10 1, Logit
     ]
    '[ 'D1 2
     , 'D1 40, 'D1 40
     , 'D1 10, 'D1 10
     , 'D1 1, 'D1 1
     ]
|
|
|
|
|
|
|
|
-- | A freshly initialised 'FFNet', built by Grenade's 'randomNetwork' in any
-- 'MonadRandom' (this program runs it in 'IO' when no model is loaded).
randomNet :: MonadRandom m => m FFNet
randomNet = randomNetwork
|
2016-06-23 15:12:57 +03:00
|
|
|
|
2016-12-10 01:31:01 +03:00
|
|
|
-- | Train a network on @n@ freshly generated examples of the two-circle
-- classification problem.
--
-- Each input is a random point in the square @[-1, 1] x [-1, 1]@. Its target
-- is @1@ when the point lies within distance 0.33 of either circle centre,
-- and @0@ otherwise. The network is updated once per example, in order,
-- using 'train' with the given learning parameters.
netTrain :: FFNet -> LearningParameters -> Int -> IO FFNet
netTrain net0 rate n = do
  inps <- replicateM n $ do
    s <- getRandom
    -- Uniform samples lie in [0, 1]; `* 2 - 1` rescales them to [-1, 1].
    return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
  let outs = map label inps
  -- Force the final network here so the training cost is paid inside
  -- netTrain, rather than deferred to whichever caller first inspects it.
  return $! foldl' trainEach net0 (zip inps outs)
  where
    -- Target output: 1 inside either circle, 0 elsewhere.
    -- `fromRational c :: SA.R 2` is the constant vector (c, c), so the
    -- centres are (0.33, 0.33) and (-0.33, -0.33).
    label :: S ('D1 2) -> S ('D1 1)
    label (S1D v)
      | v `inCircle` (fromRational 0.33, 0.33)
          || v `inCircle` (fromRational (-0.33), 0.33) = S1D $ fromRational 1
      | otherwise = S1D $ fromRational 0

    -- Is the point within Euclidean distance r of the centre o?
    inCircle :: KnownNat k => SA.R k -> (SA.R k, Double) -> Bool
    v `inCircle` (o, r) = SA.norm_2 (v - o) <= r

    -- Strict in the accumulator so foldl' does not build a chain of thunks.
    trainEach !network (i, o) = train rate network i o
|
2016-06-23 15:12:57 +03:00
|
|
|
|
2016-12-10 01:31:01 +03:00
|
|
|
-- | Load a serialised 'FFNet' from @modelPath@.
--
-- The whole file is read as a strict 'B.ByteString' and decoded with the
-- network's cereal 'Serialize' instance. If decoding fails, the error is
-- raised with 'fail' in 'IO', and the message includes the file path so the
-- user can tell which model was rejected.
netLoad :: FilePath -> IO FFNet
netLoad modelPath = do
  modelData <- B.readFile modelPath
  either failDecode return $ runGet (get :: Get FFNet) modelData
  where
    failDecode err = fail $ "netLoad: could not decode " ++ modelPath ++ ": " ++ err
|
|
|
|
|
|
|
|
-- | Print an ASCII-art map of the network's output over @[-1, 1] x [-1, 1]@.
--
-- The map has 21 rows (y from -1 to 1 in steps of 0.1) and 51 columns
-- (x from -1 to 1 in steps of 0.04). Each cell shows the network's output,
-- bucketed into one of five characters: from @' '@ (<= 0.2) up to @'#'@
-- (> 0.8).
netScore :: FFNet -> IO ()
netScore network = putStrLn (unlines [ map (cellAt row) [0 .. 50] | row <- [0 .. 20] ])
  where
    -- Map integer grid coordinates onto [-1, 1] and render the prediction.
    cellAt :: Double -> Double -> Char
    cellAt row col = shade . meanOf . runNet network . S1D $ SA.vector [col / 25 - 1, row / 10 - 1]

    -- Bucket an output value into a display character.
    shade :: Double -> Char
    shade level
      | level <= 0.2 = ' '
      | level <= 0.4 = '.'
      | level <= 0.6 = '-'
      | level <= 0.8 = '='
      | otherwise    = '#'

    -- Collapse the one-element output vector to a scalar.
    meanOf :: S ('D1 1) -> Double
    meanOf (S1D out) = SA.mean out
|
2016-06-23 15:12:57 +03:00
|
|
|
|
2016-12-10 01:31:01 +03:00
|
|
|
-- | Settings for one run of the demo, as parsed from the command line.
data FeedForwardOpts
  = FeedForwardOpts
      Int                 -- ^ number of random training examples
      LearningParameters  -- ^ learning rate, momentum and L2 regulariser
      (Maybe FilePath)    -- ^ serialised network to start from, if any
      (Maybe FilePath)    -- ^ where to save the trained network, if anywhere
|
2016-06-23 15:12:57 +03:00
|
|
|
|
|
|
|
-- | Command-line parser for 'FeedForwardOpts'.
--
-- Every option has a default, so the demo runs with no arguments. Each
-- option carries help text, a metavariable and its default value, so the
-- @--help@ output (enabled in 'main') is self-explanatory.
feedForward' :: Parser FeedForwardOpts
feedForward' =
  FeedForwardOpts
    <$> option auto
          ( long "examples" <> short 'e' <> value 100000 <> showDefault
         <> metavar "INT" <> help "Number of random training examples" )
    <*> ( LearningParameters
            <$> option auto
                  ( long "train_rate" <> short 'r' <> value 0.01 <> showDefault
                 <> metavar "DOUBLE" <> help "Learning rate" )
            <*> option auto
                  ( long "momentum" <> value 0.9 <> showDefault
                 <> metavar "DOUBLE" <> help "Momentum" )
            <*> option auto
                  ( long "l2" <> value 0.0005 <> showDefault
                 <> metavar "DOUBLE" <> help "L2 regularisation coefficient" )
        )
    <*> optional
          (strOption
            ( long "load" <> metavar "FILE"
           <> help "Start from the network serialised in FILE instead of a random one" ))
    <*> optional
          (strOption
            ( long "save" <> metavar "FILE"
           <> help "Write the trained network to FILE" ))
|
2016-06-23 15:12:57 +03:00
|
|
|
|
|
|
|
-- | Entry point of the demo:
--
-- 1. Parse the command line.
-- 2. Obtain a starting network, loaded from @--load@ if given, otherwise
--    random.
-- 3. Train the network.
-- 4. Print its output map.
-- 5. Serialise it to @--save@, if given.
main :: IO ()
main = do
  FeedForwardOpts examples rate load save <- execParser (info (feedForward' <**> helper) idm)
  net0 <- maybe randomNet netLoad load
  net  <- netTrain net0 rate examples
  netScore net
  forM_ save $ \saveFile ->
    B.writeFile saveFile (runPut (put net))
|