diff --git a/grenade.cabal b/grenade.cabal index f8e90dc..97e0164 100644 --- a/grenade.cabal +++ b/grenade.cabal @@ -43,6 +43,7 @@ library , containers >= 0.5 && < 0.6 , cereal >= 0.5 && < 0.6 , deepseq >= 1.4 && < 1.5 + , exceptions , hmatrix == 0.18.* , MonadRandom >= 0.4 && < 0.6 , primitive >= 0.6 && < 0.7 @@ -106,6 +107,8 @@ library Grenade.Recurrent.Layers.LSTM Grenade.Utils.OneHot + Grenade.Utils.Accuracy + Grenade.Utils.Accuracy.Internal includes: cbits/im2col.h cbits/gradient_descent.h diff --git a/src/Grenade/Utils/Accuracy.hs b/src/Grenade/Utils/Accuracy.hs new file mode 100644 index 0000000..9fa2f69 --- /dev/null +++ b/src/Grenade/Utils/Accuracy.hs @@ -0,0 +1,30 @@ +{-# LANGUAGE DeriveGeneric #-} + +module Grenade.Utils.Accuracy + ( Accuracy + , HyperParamAccuracy(..) + , accuracyM + ) where + +import Grenade.Core.LearningParameters + +import Grenade.Utils.Accuracy.Internal + +import GHC.Generics + +import Data.Validity + +import Data.Aeson (ToJSON, FromJSON) + +data HyperParamAccuracy = HyperParamAccuracy + { hyperParam :: LearningParameters + , testAccuracies :: [Accuracy] + , validationAccuracies :: [Accuracy] + , trainAccuracies :: [Accuracy] + } deriving (Show, Eq, Generic) + +instance ToJSON HyperParamAccuracy + +instance FromJSON HyperParamAccuracy + +instance Validity HyperParamAccuracy diff --git a/src/Grenade/Utils/Accuracy/Internal.hs b/src/Grenade/Utils/Accuracy/Internal.hs new file mode 100644 index 0000000..e9ceb76 --- /dev/null +++ b/src/Grenade/Utils/Accuracy/Internal.hs @@ -0,0 +1,30 @@ +{-# LANGUAGE DeriveGeneric #-} + +module Grenade.Utils.Accuracy.Internal where + +import Data.Validity + +import Control.Monad.Catch + +import Data.Aeson (ToJSON, FromJSON) + +import GHC.Generics + +newtype Accuracy = Accuracy Double deriving (Show, Eq, Generic) + +data AccuracyNotInRange = AccuracyNotInRange deriving (Show, Eq) + +instance Exception AccuracyNotInRange where + displayException AccuracyNotInRange = "The accuracy is not in [0,1]." + +accuracyM :: MonadThrow m => Double -> m Accuracy +accuracyM x = case 0 <= x && x <= 1 of + False -> throwM AccuracyNotInRange + True -> pure $ Accuracy x + +instance ToJSON Accuracy + +instance FromJSON Accuracy + +instance Validity Accuracy where + validate (Accuracy x) = 0 <= x && x <= 1 "The accuracy is in [0,1]" diff --git a/test/Test/Grenade/Gen.hs b/test/Test/Grenade/Gen.hs index a99c1b5..7542e77 100644 --- a/test/Test/Grenade/Gen.hs +++ b/test/Test/Grenade/Gen.hs @@ -6,6 +6,10 @@ import Grenade.Core.LearningParameters import Data.GenValidity +import Grenade.Utils.Accuracy +import Grenade.Utils.Accuracy.Internal + +import Test.QuickCheck (choose, listOf) import Test.QuickCheck.Gen (suchThat) instance GenUnchecked LearningParameters @@ -15,3 +19,17 @@ instance GenValid LearningParameters where rate <- genValid `suchThat` (> 0) momentum <- genValid `suchThat` (>= 0) LearningParameters rate momentum <$> genValid `suchThat` (>= 0) + +instance GenUnchecked Accuracy + +instance GenValid Accuracy where + genValid = Accuracy <$> choose (0,1) + +instance GenUnchecked HyperParamAccuracy + +instance GenValid HyperParamAccuracy where + genValid = do + param <- genValid + testAcc <- listOf genValid + validationAcc <- listOf genValid + HyperParamAccuracy param testAcc validationAcc <$> listOf genValid diff --git a/test/Test/Grenade/InstanceSpec.hs b/test/Test/Grenade/InstanceSpec.hs index 9e27b51..73a1d8e 100644 --- a/test/Test/Grenade/InstanceSpec.hs +++ b/test/Test/Grenade/InstanceSpec.hs @@ -8,6 +8,7 @@ import Test.Hspec import Test.Grenade.Gen () import Grenade.Core.LearningParameters +import Grenade.Utils.Accuracy import Test.Validity import Test.Validity.Aeson @@ -19,3 +20,7 @@ spec :: Spec spec = do genValidSpec @LearningParameters jsonSpecOnValid @LearningParameters + genValidSpec @Accuracy + jsonSpecOnValid @Accuracy + genValidSpec @HyperParamAccuracy + jsonSpecOnValid @HyperParamAccuracy