Added Accuracy and HyperParamAccuracy types

They keep track of LearningParameters and the corresponding accuracies.
This commit is contained in:
Nick Van den Broeck 2018-04-17 16:15:39 +02:00
parent 056c4228d3
commit e37b0fb017
5 changed files with 86 additions and 0 deletions

View File

@ -43,6 +43,7 @@ library
, containers >= 0.5 && < 0.6
, cereal >= 0.5 && < 0.6
, deepseq >= 1.4 && < 1.5
, exceptions
, hmatrix == 0.18.*
, MonadRandom >= 0.4 && < 0.6
, primitive >= 0.6 && < 0.7
@ -106,6 +107,8 @@ library
Grenade.Recurrent.Layers.LSTM
Grenade.Utils.OneHot
Grenade.Utils.Accuracy
Grenade.Utils.Accuracy.Internal
includes: cbits/im2col.h
cbits/gradient_descent.h

View File

@ -0,0 +1,30 @@
{-# LANGUAGE DeriveGeneric #-}
module Grenade.Utils.Accuracy
( Accuracy
, HyperParamAccuracy(..)
, accuracyM
) where
import Grenade.Core.LearningParameters
import Grenade.Utils.Accuracy.Internal
import GHC.Generics
import Data.Validity
import Data.Aeson (ToJSON, FromJSON)
data HyperParamAccuracy = HyperParamAccuracy
{ hyperParam :: LearningParameters
, testAccuracies :: [Accuracy]
, validationAccuracies :: [Accuracy]
, trainAccuracies :: [Accuracy]
} deriving (Show, Eq, Generic)
instance ToJSON HyperParamAccuracy
instance FromJSON HyperParamAccuracy
instance Validity HyperParamAccuracy

View File

@ -0,0 +1,30 @@
{-# LANGUAGE DeriveGeneric #-}
module Grenade.Utils.Accuracy.Internal where
import Data.Validity
import Control.Monad.Catch
import Data.Aeson (ToJSON, FromJSON)
import GHC.Generics
newtype Accuracy = Accuracy Double deriving (Show, Eq, Generic)
data AccuracyNotInRange = AccuracyNotInRange deriving (Show, Eq)
instance Exception AccuracyNotInRange where
displayException AccuracyNotInRange = "The accuracy is not in [0,1]."
accuracyM :: MonadThrow m => Double -> m Accuracy
accuracyM x = case 0 <= x && x <= 1 of
False -> throwM AccuracyNotInRange
True -> pure $ Accuracy x
instance ToJSON Accuracy
instance FromJSON Accuracy
instance Validity Accuracy where
validate (Accuracy x) = 0 <= x && x <= 1 <?@> "The accuracy is in [0,1]"

View File

@ -6,6 +6,10 @@ import Grenade.Core.LearningParameters
import Data.GenValidity
import Grenade.Utils.Accuracy
import Grenade.Utils.Accuracy.Internal
import Test.QuickCheck (choose, listOf)
import Test.QuickCheck.Gen (suchThat)
instance GenUnchecked LearningParameters
@ -15,3 +19,17 @@ instance GenValid LearningParameters where
rate <- genValid `suchThat` (> 0)
momentum <- genValid `suchThat` (>= 0)
LearningParameters rate momentum <$> genValid `suchThat` (>= 0)
instance GenUnchecked Accuracy
instance GenValid Accuracy where
genValid = Accuracy <$> choose (0,1)
instance GenUnchecked HyperParamAccuracy
instance GenValid HyperParamAccuracy where
genValid = do
param <- genValid
testAcc <- listOf genValid
validationAcc <- listOf genValid
HyperParamAccuracy param testAcc validationAcc <$> listOf genValid

View File

@ -8,6 +8,7 @@ import Test.Hspec
import Test.Grenade.Gen ()
import Grenade.Core.LearningParameters
import Grenade.Utils.Accuracy
import Test.Validity
import Test.Validity.Aeson
@ -19,3 +20,7 @@ spec :: Spec
spec = do
genValidSpec @LearningParameters
jsonSpecOnValid @LearningParameters
genValidSpec @Accuracy
jsonSpecOnValid @Accuracy
genValidSpec @HyperParamAccuracy
jsonSpecOnValid @HyperParamAccuracy