mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
GHC 8.4 compatibility
- `-XNoStarIsType` is new in GHC 8.6 - can't use `-XCPP` to conditionally enable `LANGUAGE` pragmas (implicit module declaration) - use a conditional in the .cabal file, but this is active for all modules - need to replace all `*` with `Type` and `import Data.Kind (Type)` (works in 8.4)
This commit is contained in:
parent
0bc6bd5fad
commit
49a4280a91
@ -58,6 +58,8 @@ library
|
|||||||
if impl(ghc < 8.0)
|
if impl(ghc < 8.0)
|
||||||
ghc-options: -fno-warn-incomplete-patterns
|
ghc-options: -fno-warn-incomplete-patterns
|
||||||
|
|
||||||
|
if impl(ghc >= 8.6)
|
||||||
|
default-extensions: NoStarIsType
|
||||||
|
|
||||||
exposed-modules:
|
exposed-modules:
|
||||||
Grenade
|
Grenade
|
||||||
|
@ -41,6 +41,8 @@ import Control.Monad.Random ( MonadRandom )
|
|||||||
|
|
||||||
import Data.List ( foldl' )
|
import Data.List ( foldl' )
|
||||||
|
|
||||||
|
import Data.Kind (Type)
|
||||||
|
|
||||||
import Grenade.Core.Shape
|
import Grenade.Core.Shape
|
||||||
import Grenade.Core.LearningParameters
|
import Grenade.Core.LearningParameters
|
||||||
|
|
||||||
@ -50,7 +52,7 @@ import Grenade.Core.LearningParameters
|
|||||||
class UpdateLayer x where
|
class UpdateLayer x where
|
||||||
-- | The type for the gradient for this layer.
|
-- | The type for the gradient for this layer.
|
||||||
-- Unit if there isn't a gradient to pass back.
|
-- Unit if there isn't a gradient to pass back.
|
||||||
type Gradient x :: *
|
type Gradient x :: Type
|
||||||
|
|
||||||
-- | Update a layer with its gradient and learning parameters
|
-- | Update a layer with its gradient and learning parameters
|
||||||
runUpdate :: LearningParameters -> x -> Gradient x -> x
|
runUpdate :: LearningParameters -> x -> Gradient x -> x
|
||||||
@ -72,7 +74,7 @@ class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
|
|||||||
-- | The Wengert tape for this layer. Includes all that is required
|
-- | The Wengert tape for this layer. Includes all that is required
|
||||||
-- to generate the back propagated gradients efficiently. As a
|
-- to generate the back propagated gradients efficiently. As a
|
||||||
-- default, `S i` is fine.
|
-- default, `S i` is fine.
|
||||||
type Tape x i o :: *
|
type Tape x i o :: Type
|
||||||
|
|
||||||
-- | Used in training and scoring. Take the input from the previous
|
-- | Used in training and scoring. Take the input from the previous
|
||||||
-- layer, and give the output from this layer.
|
-- layer, and give the output from this layer.
|
||||||
|
@ -36,6 +36,8 @@ import Data.Singletons
|
|||||||
import Data.Singletons.Prelude
|
import Data.Singletons.Prelude
|
||||||
import Data.Serialize
|
import Data.Serialize
|
||||||
|
|
||||||
|
import Data.Kind (Type)
|
||||||
|
|
||||||
import Grenade.Core.Layer
|
import Grenade.Core.Layer
|
||||||
import Grenade.Core.LearningParameters
|
import Grenade.Core.LearningParameters
|
||||||
import Grenade.Core.Shape
|
import Grenade.Core.Shape
|
||||||
@ -48,7 +50,7 @@ import Grenade.Core.Shape
|
|||||||
--
|
--
|
||||||
-- Can be considered to be a heterogeneous list of layers which are able to
|
-- Can be considered to be a heterogeneous list of layers which are able to
|
||||||
-- transform the data shapes of the network.
|
-- transform the data shapes of the network.
|
||||||
data Network :: [*] -> [Shape] -> * where
|
data Network :: [Type] -> [Shape] -> Type where
|
||||||
NNil :: SingI i
|
NNil :: SingI i
|
||||||
=> Network '[] '[i]
|
=> Network '[] '[i]
|
||||||
|
|
||||||
@ -66,7 +68,7 @@ instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) wh
|
|||||||
-- | Gradient of a network.
|
-- | Gradient of a network.
|
||||||
--
|
--
|
||||||
-- Parameterised on the layers of the network.
|
-- Parameterised on the layers of the network.
|
||||||
data Gradients :: [*] -> * where
|
data Gradients :: [Type] -> Type where
|
||||||
GNil :: Gradients '[]
|
GNil :: Gradients '[]
|
||||||
|
|
||||||
(:/>) :: UpdateLayer x
|
(:/>) :: UpdateLayer x
|
||||||
@ -77,7 +79,7 @@ data Gradients :: [*] -> * where
|
|||||||
-- | Wegnert Tape of a network.
|
-- | Wegnert Tape of a network.
|
||||||
--
|
--
|
||||||
-- Parameterised on the layers and shapes of the network.
|
-- Parameterised on the layers and shapes of the network.
|
||||||
data Tapes :: [*] -> [Shape] -> * where
|
data Tapes :: [Type] -> [Shape] -> Type where
|
||||||
TNil :: SingI i
|
TNil :: SingI i
|
||||||
=> Tapes '[] '[i]
|
=> Tapes '[] '[i]
|
||||||
|
|
||||||
@ -152,7 +154,7 @@ applyUpdate _ NNil GNil
|
|||||||
|
|
||||||
-- | A network can easily be created by hand with (:~>), but an easy way to
|
-- | A network can easily be created by hand with (:~>), but an easy way to
|
||||||
-- initialise a random network is with the randomNetwork.
|
-- initialise a random network is with the randomNetwork.
|
||||||
class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
|
class CreatableNetwork (xs :: [Type]) (ss :: [Shape]) where
|
||||||
-- | Create a network with randomly initialised weights.
|
-- | Create a network with randomly initialised weights.
|
||||||
--
|
--
|
||||||
-- Calls to this function will not compile if the type of the neural
|
-- Calls to this function will not compile if the type of the neural
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE RankNTypes #-}
|
{-# LANGUAGE RankNTypes #-}
|
||||||
{-# LANGUAGE NoStarIsType #-}
|
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
{-|
|
{-|
|
||||||
Module : Grenade.Core.Shape
|
Module : Grenade.Core.Shape
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE ScopedTypeVariables #-}
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
{-# LANGUAGE StandaloneDeriving #-}
|
{-# LANGUAGE StandaloneDeriving #-}
|
||||||
{-# LANGUAGE NoStarIsType #-}
|
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
{-|
|
{-|
|
||||||
Module : Grenade.Layers.Concat
|
Module : Grenade.Layers.Concat
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE NoStarIsType #-}
|
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
{-|
|
{-|
|
||||||
Module : Grenade.Layers.Convolution
|
Module : Grenade.Layers.Convolution
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE NoStarIsType #-}
|
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
{-|
|
{-|
|
||||||
Module : Grenade.Layers.Crop
|
Module : Grenade.Layers.Crop
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE FlexibleInstances #-}
|
{-# LANGUAGE FlexibleInstances #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE NoStarIsType #-}
|
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
{-|
|
{-|
|
||||||
Module : Grenade.Layers.Deconvolution
|
Module : Grenade.Layers.Deconvolution
|
||||||
|
@ -23,13 +23,15 @@ import Data.Serialize
|
|||||||
|
|
||||||
import Data.Singletons
|
import Data.Singletons
|
||||||
|
|
||||||
|
import Data.Kind (Type)
|
||||||
|
|
||||||
import Grenade.Core
|
import Grenade.Core
|
||||||
|
|
||||||
-- | A Merging layer.
|
-- | A Merging layer.
|
||||||
--
|
--
|
||||||
-- Similar to Concat layer, except sums the activations instead of creating a larger
|
-- Similar to Concat layer, except sums the activations instead of creating a larger
|
||||||
-- shape.
|
-- shape.
|
||||||
data Merge :: * -> * -> * where
|
data Merge :: Type -> Type -> Type where
|
||||||
Merge :: x -> y -> Merge x y
|
Merge :: x -> y -> Merge x y
|
||||||
|
|
||||||
instance (Show x, Show y) => Show (Merge x y) where
|
instance (Show x, Show y) => Show (Merge x y) where
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE NoStarIsType #-}
|
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
{-|
|
{-|
|
||||||
Module : Grenade.Core.Pad
|
Module : Grenade.Core.Pad
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE NoStarIsType #-}
|
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
{-|
|
{-|
|
||||||
Module : Grenade.Core.Pooling
|
Module : Grenade.Core.Pooling
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
{-# LANGUAGE TypeFamilies #-}
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||||
{-# LANGUAGE FlexibleContexts #-}
|
{-# LANGUAGE FlexibleContexts #-}
|
||||||
{-# LANGUAGE NoStarIsType #-}
|
|
||||||
{-# LANGUAGE UndecidableInstances #-}
|
{-# LANGUAGE UndecidableInstances #-}
|
||||||
{-|
|
{-|
|
||||||
Module : Grenade.Layers.Reshape
|
Module : Grenade.Layers.Reshape
|
||||||
|
@ -8,6 +8,8 @@ module Grenade.Recurrent.Core.Layer (
|
|||||||
, RecurrentUpdateLayer (..)
|
, RecurrentUpdateLayer (..)
|
||||||
) where
|
) where
|
||||||
|
|
||||||
|
import Data.Kind (Type)
|
||||||
|
|
||||||
import Grenade.Core
|
import Grenade.Core
|
||||||
|
|
||||||
-- | Class for a recurrent layer.
|
-- | Class for a recurrent layer.
|
||||||
@ -15,11 +17,11 @@ import Grenade.Core
|
|||||||
-- of an extra recurrent data shape.
|
-- of an extra recurrent data shape.
|
||||||
class UpdateLayer x => RecurrentUpdateLayer x where
|
class UpdateLayer x => RecurrentUpdateLayer x where
|
||||||
-- | Shape of data that is passed between each subsequent run of the layer
|
-- | Shape of data that is passed between each subsequent run of the layer
|
||||||
type RecurrentShape x :: *
|
type RecurrentShape x :: Type
|
||||||
|
|
||||||
class (RecurrentUpdateLayer x, Num (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
|
class (RecurrentUpdateLayer x, Num (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
|
||||||
-- | Wengert Tape
|
-- | Wengert Tape
|
||||||
type RecTape x i o :: *
|
type RecTape x i o :: Type
|
||||||
-- | Used in training and scoring. Take the input from the previous
|
-- | Used in training and scoring. Take the input from the previous
|
||||||
-- layer, and give the output from this layer.
|
-- layer, and give the output from this layer.
|
||||||
runRecurrentForwards :: x -> RecurrentShape x -> S i -> (RecTape x i o, RecurrentShape x, S o)
|
runRecurrentForwards :: x -> RecurrentShape x -> S i -> (RecTape x i o, RecurrentShape x, S o)
|
||||||
|
@ -32,18 +32,20 @@ import Data.Singletons ( SingI )
|
|||||||
import Data.Singletons.Prelude ( Head, Last )
|
import Data.Singletons.Prelude ( Head, Last )
|
||||||
import Data.Serialize
|
import Data.Serialize
|
||||||
|
|
||||||
|
import Data.Kind (Type)
|
||||||
|
|
||||||
import Grenade.Core
|
import Grenade.Core
|
||||||
import Grenade.Recurrent.Core.Layer
|
import Grenade.Recurrent.Core.Layer
|
||||||
|
|
||||||
-- | Witness type to say indicate we're building up with a normal feed
|
-- | Witness type to say indicate we're building up with a normal feed
|
||||||
-- forward layer.
|
-- forward layer.
|
||||||
data FeedForward :: * -> *
|
data FeedForward :: Type -> Type
|
||||||
-- | Witness type to say indicate we're building up with a recurrent layer.
|
-- | Witness type to say indicate we're building up with a recurrent layer.
|
||||||
data Recurrent :: * -> *
|
data Recurrent :: Type -> Type
|
||||||
|
|
||||||
-- | Type of a recurrent neural network.
|
-- | Type of a recurrent neural network.
|
||||||
--
|
--
|
||||||
-- The [*] type specifies the types of the layers.
|
-- The [Type] type specifies the types of the layers.
|
||||||
--
|
--
|
||||||
-- The [Shape] type specifies the shapes of data passed between the layers.
|
-- The [Shape] type specifies the shapes of data passed between the layers.
|
||||||
--
|
--
|
||||||
@ -52,7 +54,7 @@ data Recurrent :: * -> *
|
|||||||
--
|
--
|
||||||
-- Often, to make the definitions more concise, one will use a type alias
|
-- Often, to make the definitions more concise, one will use a type alias
|
||||||
-- for these empty data types.
|
-- for these empty data types.
|
||||||
data RecurrentNetwork :: [*] -> [Shape] -> * where
|
data RecurrentNetwork :: [Type] -> [Shape] -> Type where
|
||||||
RNil :: SingI i
|
RNil :: SingI i
|
||||||
=> RecurrentNetwork '[] '[i]
|
=> RecurrentNetwork '[] '[i]
|
||||||
|
|
||||||
@ -71,7 +73,7 @@ infixr 5 :~@>
|
|||||||
-- | Gradient of a network.
|
-- | Gradient of a network.
|
||||||
--
|
--
|
||||||
-- Parameterised on the layers of the network.
|
-- Parameterised on the layers of the network.
|
||||||
data RecurrentGradient :: [*] -> * where
|
data RecurrentGradient :: [Type] -> Type where
|
||||||
RGNil :: RecurrentGradient '[]
|
RGNil :: RecurrentGradient '[]
|
||||||
|
|
||||||
(://>) :: UpdateLayer x
|
(://>) :: UpdateLayer x
|
||||||
@ -81,7 +83,7 @@ data RecurrentGradient :: [*] -> * where
|
|||||||
|
|
||||||
-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph)
|
-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph)
|
||||||
-- Parameterised on the layers of a Network.
|
-- Parameterised on the layers of a Network.
|
||||||
data RecurrentInputs :: [*] -> * where
|
data RecurrentInputs :: [Type] -> Type where
|
||||||
RINil :: RecurrentInputs '[]
|
RINil :: RecurrentInputs '[]
|
||||||
|
|
||||||
(:~~+>) :: (UpdateLayer x, Fractional (RecurrentInputs xs))
|
(:~~+>) :: (UpdateLayer x, Fractional (RecurrentInputs xs))
|
||||||
@ -95,7 +97,7 @@ data RecurrentInputs :: [*] -> * where
|
|||||||
--
|
--
|
||||||
-- We index on the time step length as well, to ensure
|
-- We index on the time step length as well, to ensure
|
||||||
-- that that all Tape lengths are the same.
|
-- that that all Tape lengths are the same.
|
||||||
data RecurrentTape :: [*] -> [Shape] -> * where
|
data RecurrentTape :: [Type] -> [Shape] -> Type where
|
||||||
TRNil :: SingI i
|
TRNil :: SingI i
|
||||||
=> RecurrentTape '[] '[i]
|
=> RecurrentTape '[] '[i]
|
||||||
|
|
||||||
@ -204,7 +206,7 @@ instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recu
|
|||||||
|
|
||||||
-- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
|
-- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
|
||||||
-- recurrent network and a set of random inputs for it is with the randomRecurrent.
|
-- recurrent network and a set of random inputs for it is with the randomRecurrent.
|
||||||
class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where
|
class CreatableRecurrent (xs :: [Type]) (ss :: [Shape]) where
|
||||||
-- | Create a network of the types requested
|
-- | Create a network of the types requested
|
||||||
randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss)
|
randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss)
|
||||||
|
|
||||||
|
@ -18,6 +18,8 @@ import Control.Monad.Random ( MonadRandom, getRandom )
|
|||||||
|
|
||||||
import Data.Singletons.TypeLits
|
import Data.Singletons.TypeLits
|
||||||
|
|
||||||
|
import Data.Kind (Type)
|
||||||
|
|
||||||
import Numeric.LinearAlgebra.Static
|
import Numeric.LinearAlgebra.Static
|
||||||
|
|
||||||
import GHC.TypeLits
|
import GHC.TypeLits
|
||||||
@ -27,7 +29,7 @@ import Grenade.Recurrent.Core
|
|||||||
|
|
||||||
data BasicRecurrent :: Nat -- Input layer size
|
data BasicRecurrent :: Nat -- Input layer size
|
||||||
-> Nat -- Output layer size
|
-> Nat -- Output layer size
|
||||||
-> * where
|
-> Type where
|
||||||
BasicRecurrent :: ( KnownNat input
|
BasicRecurrent :: ( KnownNat input
|
||||||
, KnownNat output
|
, KnownNat output
|
||||||
, KnownNat matrixCols
|
, KnownNat matrixCols
|
||||||
@ -40,7 +42,7 @@ data BasicRecurrent :: Nat -- Input layer size
|
|||||||
|
|
||||||
data BasicRecurrent' :: Nat -- Input layer size
|
data BasicRecurrent' :: Nat -- Input layer size
|
||||||
-> Nat -- Output layer size
|
-> Nat -- Output layer size
|
||||||
-> * where
|
-> Type where
|
||||||
BasicRecurrent' :: ( KnownNat input
|
BasicRecurrent' :: ( KnownNat input
|
||||||
, KnownNat output
|
, KnownNat output
|
||||||
, KnownNat matrixCols
|
, KnownNat matrixCols
|
||||||
|
@ -27,6 +27,8 @@ import Data.Serialize
|
|||||||
import Data.Singletons
|
import Data.Singletons
|
||||||
import GHC.TypeLits
|
import GHC.TypeLits
|
||||||
|
|
||||||
|
import Data.Kind (Type)
|
||||||
|
|
||||||
import Grenade.Core
|
import Grenade.Core
|
||||||
import Grenade.Recurrent.Core
|
import Grenade.Recurrent.Core
|
||||||
|
|
||||||
@ -45,7 +47,7 @@ import Numeric.LinearAlgebra.Static ( (#), split, R )
|
|||||||
--
|
--
|
||||||
-- 3D images become 3D images with more channels. The sizes must be the same, one can use Pad
|
-- 3D images become 3D images with more channels. The sizes must be the same, one can use Pad
|
||||||
-- and Crop layers to ensure this is the case.
|
-- and Crop layers to ensure this is the case.
|
||||||
data ConcatRecurrent :: Shape -> * -> Shape -> * -> * where
|
data ConcatRecurrent :: Shape -> Type -> Shape -> Type -> Type where
|
||||||
ConcatRecLeft :: x -> y -> ConcatRecurrent m (Recurrent x) n (FeedForward y)
|
ConcatRecLeft :: x -> y -> ConcatRecurrent m (Recurrent x) n (FeedForward y)
|
||||||
ConcatRecRight :: x -> y -> ConcatRecurrent m (FeedForward x) n (Recurrent y)
|
ConcatRecRight :: x -> y -> ConcatRecurrent m (FeedForward x) n (Recurrent y)
|
||||||
ConcatRecBoth :: x -> y -> ConcatRecurrent m (Recurrent x) n (Recurrent y)
|
ConcatRecBoth :: x -> y -> ConcatRecurrent m (Recurrent x) n (Recurrent y)
|
||||||
|
@ -23,6 +23,8 @@ import Data.Proxy
|
|||||||
import Data.Serialize
|
import Data.Serialize
|
||||||
import Data.Singletons.TypeLits
|
import Data.Singletons.TypeLits
|
||||||
|
|
||||||
|
import Data.Kind (Type)
|
||||||
|
|
||||||
import qualified Numeric.LinearAlgebra as LA
|
import qualified Numeric.LinearAlgebra as LA
|
||||||
import Numeric.LinearAlgebra.Static
|
import Numeric.LinearAlgebra.Static
|
||||||
|
|
||||||
@ -36,14 +38,14 @@ import Grenade.Layers.Internal.Update
|
|||||||
-- This is a Peephole formulation, so the recurrent shape is
|
-- This is a Peephole formulation, so the recurrent shape is
|
||||||
-- just the cell state, the previous output is not held or used
|
-- just the cell state, the previous output is not held or used
|
||||||
-- at all.
|
-- at all.
|
||||||
data LSTM :: Nat -> Nat -> * where
|
data LSTM :: Nat -> Nat -> Type where
|
||||||
LSTM :: ( KnownNat input
|
LSTM :: ( KnownNat input
|
||||||
, KnownNat output
|
, KnownNat output
|
||||||
) => !(LSTMWeights input output) -- Weights
|
) => !(LSTMWeights input output) -- Weights
|
||||||
-> !(LSTMWeights input output) -- Momentums
|
-> !(LSTMWeights input output) -- Momentums
|
||||||
-> LSTM input output
|
-> LSTM input output
|
||||||
|
|
||||||
data LSTMWeights :: Nat -> Nat -> * where
|
data LSTMWeights :: Nat -> Nat -> Type where
|
||||||
LSTMWeights :: ( KnownNat input
|
LSTMWeights :: ( KnownNat input
|
||||||
, KnownNat output
|
, KnownNat output
|
||||||
) => {
|
) => {
|
||||||
|
Loading…
Reference in New Issue
Block a user