mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Implement Dropout layer
This commit is contained in:
parent
3d1919fab4
commit
e710046a3b
@ -33,7 +33,7 @@ class Layer (m :: * -> *) x (i :: Shape) (o :: Shape) where
|
||||
-- | Type of a network.
|
||||
-- The [Shape] type specifies the shapes of data passed between the layers.
|
||||
-- Could be considered to be a heterogeneous list of layers which are able to
|
||||
-- transform the date shapes of the network.
|
||||
-- transform the data shapes of the network.
|
||||
data Network :: (* -> *) -> [Shape] -> * where
|
||||
O :: (Show x, Layer m x i o, KnownShape o, KnownShape i)
|
||||
=> !x
|
||||
|
@ -19,6 +19,7 @@ import Grenade.Core.Network
|
||||
|
||||
import Numeric.LinearAlgebra.Static
|
||||
|
||||
|
||||
-- Dropout layer help to reduce overfitting.
|
||||
-- Idea here is that the vector is a shape of 1s and 0s, which we multiply the input by.
|
||||
-- After backpropogation, we return a new matrix/vector, with different bits dropped out.
|
||||
@ -27,6 +28,16 @@ import Numeric.LinearAlgebra.Static
|
||||
data Dropout o = Dropout Double (R o)
|
||||
deriving Show
|
||||
|
||||
randomDropout :: (MonadRandom m, KnownNat i)
|
||||
=> Double -> m (Dropout i)
|
||||
randomDropout rate = do
|
||||
seed <- getRandom
|
||||
let wN = randomVector seed Uniform
|
||||
xs = dvmap (\a -> if a <= rate then 0 else 1) wN
|
||||
return $ Dropout rate xs
|
||||
|
||||
instance (MonadRandom m, KnownNat i) => Layer m (Dropout i) ('D1 i) ('D1 i) where
|
||||
runForwards _ _= error "todo"
|
||||
runBackards _ _ _ _ = error "todo"
|
||||
runForwards (Dropout _ drops) (S1D' x) = return . S1D' $ x * drops
|
||||
runBackards _ (Dropout rate drops) _ (S1D' x) = do
|
||||
newDropout <- randomDropout rate
|
||||
return (newDropout, S1D' $ x * drops)
|
||||
|
Loading…
Reference in New Issue
Block a user