mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Added Accuracy and HyperParamAccuracy types
They keep track of LearningParameters and the corresponding accuracies.
This commit is contained in:
parent
056c4228d3
commit
e37b0fb017
@ -43,6 +43,7 @@ library
|
||||
, containers >= 0.5 && < 0.6
|
||||
, cereal >= 0.5 && < 0.6
|
||||
, deepseq >= 1.4 && < 1.5
|
||||
, exceptions
|
||||
, hmatrix == 0.18.*
|
||||
, MonadRandom >= 0.4 && < 0.6
|
||||
, primitive >= 0.6 && < 0.7
|
||||
@ -106,6 +107,8 @@ library
|
||||
Grenade.Recurrent.Layers.LSTM
|
||||
|
||||
Grenade.Utils.OneHot
|
||||
Grenade.Utils.Accuracy
|
||||
Grenade.Utils.Accuracy.Internal
|
||||
|
||||
includes: cbits/im2col.h
|
||||
cbits/gradient_descent.h
|
||||
|
30
src/Grenade/Utils/Accuracy.hs
Normal file
30
src/Grenade/Utils/Accuracy.hs
Normal file
@ -0,0 +1,30 @@
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
|
||||
module Grenade.Utils.Accuracy
|
||||
( Accuracy
|
||||
, HyperParamAccuracy(..)
|
||||
, accuracyM
|
||||
) where
|
||||
|
||||
import Grenade.Core.LearningParameters
|
||||
|
||||
import Grenade.Utils.Accuracy.Internal
|
||||
|
||||
import GHC.Generics
|
||||
|
||||
import Data.Validity
|
||||
|
||||
import Data.Aeson (ToJSON, FromJSON)
|
||||
|
||||
data HyperParamAccuracy = HyperParamAccuracy
|
||||
{ hyperParam :: LearningParameters
|
||||
, testAccuracies :: [Accuracy]
|
||||
, validationAccuracies :: [Accuracy]
|
||||
, trainAccuracies :: [Accuracy]
|
||||
} deriving (Show, Eq, Generic)
|
||||
|
||||
instance ToJSON HyperParamAccuracy
|
||||
|
||||
instance FromJSON HyperParamAccuracy
|
||||
|
||||
instance Validity HyperParamAccuracy
|
30
src/Grenade/Utils/Accuracy/Internal.hs
Normal file
30
src/Grenade/Utils/Accuracy/Internal.hs
Normal file
@ -0,0 +1,30 @@
|
||||
{-# LANGUAGE DeriveGeneric #-}
|
||||
|
||||
module Grenade.Utils.Accuracy.Internal where
|
||||
|
||||
import Data.Validity
|
||||
|
||||
import Control.Monad.Catch
|
||||
|
||||
import Data.Aeson (ToJSON, FromJSON)
|
||||
|
||||
import GHC.Generics
|
||||
|
||||
newtype Accuracy = Accuracy Double deriving (Show, Eq, Generic)
|
||||
|
||||
data AccuracyNotInRange = AccuracyNotInRange deriving (Show, Eq)
|
||||
|
||||
instance Exception AccuracyNotInRange where
|
||||
displayException AccuracyNotInRange = "The accuracy is not in [0,1]."
|
||||
|
||||
accuracyM :: MonadThrow m => Double -> m Accuracy
|
||||
accuracyM x = case 0 <= x && x <= 1 of
|
||||
False -> throwM AccuracyNotInRange
|
||||
True -> pure $ Accuracy x
|
||||
|
||||
instance ToJSON Accuracy
|
||||
|
||||
instance FromJSON Accuracy
|
||||
|
||||
instance Validity Accuracy where
|
||||
validate (Accuracy x) = 0 <= x && x <= 1 <?@> "The accuracy is in [0,1]"
|
@ -6,6 +6,10 @@ import Grenade.Core.LearningParameters
|
||||
|
||||
import Data.GenValidity
|
||||
|
||||
import Grenade.Utils.Accuracy
|
||||
import Grenade.Utils.Accuracy.Internal
|
||||
|
||||
import Test.QuickCheck (choose, listOf)
|
||||
import Test.QuickCheck.Gen (suchThat)
|
||||
|
||||
instance GenUnchecked LearningParameters
|
||||
@ -15,3 +19,17 @@ instance GenValid LearningParameters where
|
||||
rate <- genValid `suchThat` (> 0)
|
||||
momentum <- genValid `suchThat` (>= 0)
|
||||
LearningParameters rate momentum <$> genValid `suchThat` (>= 0)
|
||||
|
||||
instance GenUnchecked Accuracy
|
||||
|
||||
instance GenValid Accuracy where
|
||||
genValid = Accuracy <$> choose (0,1)
|
||||
|
||||
instance GenUnchecked HyperParamAccuracy
|
||||
|
||||
instance GenValid HyperParamAccuracy where
|
||||
genValid = do
|
||||
param <- genValid
|
||||
testAcc <- listOf genValid
|
||||
validationAcc <- listOf genValid
|
||||
HyperParamAccuracy param testAcc validationAcc <$> listOf genValid
|
||||
|
@ -8,6 +8,7 @@ import Test.Hspec
|
||||
import Test.Grenade.Gen ()
|
||||
|
||||
import Grenade.Core.LearningParameters
|
||||
import Grenade.Utils.Accuracy
|
||||
|
||||
import Test.Validity
|
||||
import Test.Validity.Aeson
|
||||
@ -19,3 +20,7 @@ spec :: Spec
|
||||
spec = do
|
||||
genValidSpec @LearningParameters
|
||||
jsonSpecOnValid @LearningParameters
|
||||
genValidSpec @Accuracy
|
||||
jsonSpecOnValid @Accuracy
|
||||
genValidSpec @HyperParamAccuracy
|
||||
jsonSpecOnValid @HyperParamAccuracy
|
||||
|
Loading…
Reference in New Issue
Block a user