mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Use ExceptT in mnist
This commit is contained in:
parent
3f7cf7f9a6
commit
2cb93c6a5b
@ -8,7 +8,6 @@
|
||||
|
||||
import Control.Monad
|
||||
import Control.Monad.Random
|
||||
|
||||
import GHC.TypeLits
|
||||
|
||||
import qualified Numeric.LinearAlgebra.Static as SA
|
||||
@ -51,7 +50,7 @@ netTest rate n = do
|
||||
where
|
||||
inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
|
||||
v `inCircle` (o, r) = SA.norm_2 (v - o) <= r
|
||||
trainEach !nt !(i, o) = train rate nt i o
|
||||
trainEach !network (i,o) = train rate network i o
|
||||
|
||||
render n' | n' <= 0.2 = ' '
|
||||
| n' <= 0.4 = '.'
|
||||
|
@ -9,6 +9,8 @@
|
||||
import Control.Applicative
|
||||
import Control.Monad
|
||||
import Control.Monad.Random
|
||||
import Control.Monad.Trans.Class
|
||||
import Control.Monad.Trans.Except
|
||||
|
||||
import qualified Data.Attoparsec.Text as A
|
||||
import qualified Data.Text as T
|
||||
@ -35,35 +37,23 @@ randomMnist :: MonadRandom m
|
||||
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
|
||||
randomMnist = randomNetwork
|
||||
|
||||
|
||||
convTest :: Int -> FilePath -> FilePath -> LearningParameters -> IO ()
|
||||
convTest :: Int -> FilePath -> FilePath -> LearningParameters -> ExceptT String IO ()
|
||||
convTest iterations trainFile validateFile rate = do
|
||||
net0 <- evalRandIO randomMnist
|
||||
fT <- T.readFile trainFile
|
||||
fV <- T.readFile validateFile
|
||||
let trainRows = traverse (A.parseOnly p) (T.lines fT)
|
||||
let validateRows = traverse (A.parseOnly p) (T.lines fV)
|
||||
case (trainRows, validateRows) of
|
||||
(Right tr', Right vr') -> foldM_ (runIteration tr' vr') net0 [1..iterations]
|
||||
err -> print err
|
||||
net0 <- lift randomMnist
|
||||
trainData <- readMNIST trainFile
|
||||
validateData <- readMNIST validateFile
|
||||
lift $ foldM_ (runIteration trainData validateData) net0 [1..iterations]
|
||||
|
||||
where
|
||||
trainEach !rate' !nt !(i, o) = train rate' nt i o
|
||||
where
|
||||
trainEach rate' !network (i, o) = train rate' network i o
|
||||
|
||||
p :: A.Parser (S' ('D2 28 28), S' ('D1 10))
|
||||
p = do
|
||||
lab <- A.decimal
|
||||
pixels <- many (A.char ',' >> A.double)
|
||||
let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
|
||||
return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')
|
||||
|
||||
runIteration trainRows validateRows net i = do
|
||||
let trained' = foldl (trainEach rate) net trainRows
|
||||
let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
|
||||
let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
|
||||
print trained'
|
||||
putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
|
||||
return trained'
|
||||
runIteration trainRows validateRows net i = do
|
||||
let trained' = foldl (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
|
||||
let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
|
||||
let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
|
||||
print trained'
|
||||
putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
|
||||
return trained'
|
||||
|
||||
data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters
|
||||
|
||||
@ -81,4 +71,20 @@ main :: IO ()
|
||||
main = do
|
||||
MnistOpts mnist vali iter rate <- execParser (info (mnist' <**> helper) idm)
|
||||
putStrLn "Training convolutional neural network..."
|
||||
convTest iter mnist vali rate
|
||||
|
||||
res <- runExceptT $ convTest iter mnist vali rate
|
||||
case res of
|
||||
Right () -> pure ()
|
||||
Left err -> putStrLn err
|
||||
|
||||
readMNIST :: FilePath -> ExceptT String IO [(S' ('D2 28 28), S' ('D1 10))]
|
||||
readMNIST mnist = ExceptT $ do
|
||||
mnistdata <- T.readFile mnist
|
||||
return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)
|
||||
|
||||
parseMNIST :: A.Parser (S' ('D2 28 28), S' ('D1 10))
|
||||
parseMNIST = do
|
||||
lab <- A.decimal
|
||||
pixels <- many (A.char ',' >> A.double)
|
||||
let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
|
||||
return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')
|
||||
|
Loading…
Reference in New Issue
Block a user