Use ExceptT in mnist

This commit is contained in:
Huw Campbell 2016-12-07 14:48:58 +11:00
parent 3f7cf7f9a6
commit 2cb93c6a5b
2 changed files with 34 additions and 29 deletions

View File

@ -8,7 +8,6 @@
import Control.Monad
import Control.Monad.Random
import GHC.TypeLits
import qualified Numeric.LinearAlgebra.Static as SA
@ -51,7 +50,7 @@ netTest rate n = do
where
inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
v `inCircle` (o, r) = SA.norm_2 (v - o) <= r
trainEach !nt !(i, o) = train rate nt i o
trainEach !network (i,o) = train rate network i o
render n' | n' <= 0.2 = ' '
| n' <= 0.4 = '.'

View File

@ -9,6 +9,8 @@
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import qualified Data.Attoparsec.Text as A
import qualified Data.Text as T
@ -35,35 +37,23 @@ randomMnist :: MonadRandom m
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
randomMnist = randomNetwork
convTest :: Int -> FilePath -> FilePath -> LearningParameters -> IO ()
convTest :: Int -> FilePath -> FilePath -> LearningParameters -> ExceptT String IO ()
convTest iterations trainFile validateFile rate = do
net0 <- evalRandIO randomMnist
fT <- T.readFile trainFile
fV <- T.readFile validateFile
let trainRows = traverse (A.parseOnly p) (T.lines fT)
let validateRows = traverse (A.parseOnly p) (T.lines fV)
case (trainRows, validateRows) of
(Right tr', Right vr') -> foldM_ (runIteration tr' vr') net0 [1..iterations]
err -> print err
net0 <- lift randomMnist
trainData <- readMNIST trainFile
validateData <- readMNIST validateFile
lift $ foldM_ (runIteration trainData validateData) net0 [1..iterations]
where
trainEach !rate' !nt !(i, o) = train rate' nt i o
where
trainEach rate' !network (i, o) = train rate' network i o
p :: A.Parser (S' ('D2 28 28), S' ('D1 10))
p = do
lab <- A.decimal
pixels <- many (A.char ',' >> A.double)
let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')
runIteration trainRows validateRows net i = do
let trained' = foldl (trainEach rate) net trainRows
let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
print trained'
putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
return trained'
runIteration trainRows validateRows net i = do
let trained' = foldl (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
print trained'
putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
return trained'
data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters
@ -81,4 +71,20 @@ main :: IO ()
main = do
MnistOpts mnist vali iter rate <- execParser (info (mnist' <**> helper) idm)
putStrLn "Training convolutional neural network..."
convTest iter mnist vali rate
res <- runExceptT $ convTest iter mnist vali rate
case res of
Right () -> pure ()
Left err -> putStrLn err
readMNIST :: FilePath -> ExceptT String IO [(S' ('D2 28 28), S' ('D1 10))]
readMNIST mnist = ExceptT $ do
mnistdata <- T.readFile mnist
return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)
parseMNIST :: A.Parser (S' ('D2 28 28), S' ('D1 10))
parseMNIST = do
lab <- A.decimal
pixels <- many (A.char ',' >> A.double)
let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')