diff --git a/grenade.cabal b/grenade.cabal
index 087ab95..ae84102 100644
--- a/grenade.cabal
+++ b/grenade.cabal
@@ -58,6 +58,8 @@ library
   if impl(ghc < 8.0)
     ghc-options: -fno-warn-incomplete-patterns
 
+  if impl(ghc >= 8.6)
+    default-extensions: NoStarIsType
 
   exposed-modules:
     Grenade
diff --git a/src/Grenade/Core/Layer.hs b/src/Grenade/Core/Layer.hs
index 892dace..7a9f2f2 100644
--- a/src/Grenade/Core/Layer.hs
+++ b/src/Grenade/Core/Layer.hs
@@ -41,6 +41,8 @@
 import Control.Monad.Random ( MonadRandom )
 
 import Data.List ( foldl' )
 
+import Data.Kind (Type)
+
 import Grenade.Core.Shape
 import Grenade.Core.LearningParameters
@@ -50,7 +52,7 @@ import Grenade.Core.LearningParameters
 class UpdateLayer x where
   -- | The type for the gradient for this layer.
   -- Unit if there isn't a gradient to pass back.
-  type Gradient x :: *
+  type Gradient x :: Type
 
   -- | Update a layer with its gradient and learning parameters
   runUpdate :: LearningParameters -> x -> Gradient x -> x
@@ -72,7 +74,7 @@ class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
   -- | The Wengert tape for this layer. Includes all that is required
   -- to generate the back propagated gradients efficiently. As a
   -- default, `S i` is fine.
-  type Tape x i o :: *
+  type Tape x i o :: Type
 
   -- | Used in training and scoring. Take the input from the previous
   -- layer, and give the output from this layer.
diff --git a/src/Grenade/Core/Network.hs b/src/Grenade/Core/Network.hs
index 2e6cedb..ec1a2a2 100644
--- a/src/Grenade/Core/Network.hs
+++ b/src/Grenade/Core/Network.hs
@@ -36,6 +36,8 @@
 import Data.Singletons
 import Data.Singletons.Prelude
 import Data.Serialize
+import Data.Kind (Type)
+
 import Grenade.Core.Layer
 import Grenade.Core.LearningParameters
 import Grenade.Core.Shape
@@ -48,7 +50,7 @@ import Grenade.Core.Shape
 --
 -- Can be considered to be a heterogeneous list of layers which are able to
 -- transform the data shapes of the network.
-data Network :: [*] -> [Shape] -> * where
+data Network :: [Type] -> [Shape] -> Type where
   NNil :: SingI i
        => Network '[] '[i]
 
@@ -66,7 +68,7 @@ instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) wh
 -- | Gradient of a network.
 --
 -- Parameterised on the layers of the network.
-data Gradients :: [*] -> * where
+data Gradients :: [Type] -> Type where
   GNil :: Gradients '[]
 
   (:/>) :: UpdateLayer x
@@ -77,7 +79,7 @@ data Gradients :: [*] -> * where
 -- | Wegnert Tape of a network.
 --
 -- Parameterised on the layers and shapes of the network.
-data Tapes :: [*] -> [Shape] -> * where
+data Tapes :: [Type] -> [Shape] -> Type where
   TNil :: SingI i
        => Tapes '[] '[i]
 
@@ -152,7 +154,7 @@ applyUpdate _ NNil GNil
 
 -- | A network can easily be created by hand with (:~>), but an easy way to
 -- initialise a random network is with the randomNetwork.
-class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
+class CreatableNetwork (xs :: [Type]) (ss :: [Shape]) where
   -- | Create a network with randomly initialised weights.
 --
 -- Calls to this function will not compile if the type of the neural
diff --git a/src/Grenade/Core/Shape.hs b/src/Grenade/Core/Shape.hs
index 195ee19..fec7116 100644
--- a/src/Grenade/Core/Shape.hs
+++ b/src/Grenade/Core/Shape.hs
@@ -8,7 +8,6 @@
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE NoStarIsType #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-|
 Module : Grenade.Core.Shape
diff --git a/src/Grenade/Layers/Concat.hs b/src/Grenade/Layers/Concat.hs
index 31617d7..1ea18ec 100644
--- a/src/Grenade/Layers/Concat.hs
+++ b/src/Grenade/Layers/Concat.hs
@@ -8,7 +8,6 @@
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE NoStarIsType #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-|
 Module : Grenade.Layers.Concat
diff --git a/src/Grenade/Layers/Convolution.hs b/src/Grenade/Layers/Convolution.hs
index 4e53f45..13ca165 100644
--- a/src/Grenade/Layers/Convolution.hs
+++ b/src/Grenade/Layers/Convolution.hs
@@ -8,7 +8,6 @@
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE NoStarIsType #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-|
 Module : Grenade.Layers.Convolution
diff --git a/src/Grenade/Layers/Crop.hs b/src/Grenade/Layers/Crop.hs
index 1400fdf..a0cf181 100644
--- a/src/Grenade/Layers/Crop.hs
+++ b/src/Grenade/Layers/Crop.hs
@@ -6,7 +6,6 @@
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE NoStarIsType #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-|
 Module : Grenade.Layers.Crop
diff --git a/src/Grenade/Layers/Deconvolution.hs b/src/Grenade/Layers/Deconvolution.hs
index f60cda3..0f4597a 100644
--- a/src/Grenade/Layers/Deconvolution.hs
+++ b/src/Grenade/Layers/Deconvolution.hs
@@ -8,7 +8,6 @@
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE NoStarIsType #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-|
 Module : Grenade.Layers.Deconvolution
diff --git a/src/Grenade/Layers/Merge.hs b/src/Grenade/Layers/Merge.hs
index 6133094..8a9c814 100644
--- a/src/Grenade/Layers/Merge.hs
+++ b/src/Grenade/Layers/Merge.hs
@@ -23,13 +23,15 @@
 import Data.Serialize
 import Data.Singletons
 
+import Data.Kind (Type)
+
 import Grenade.Core
 
 -- | A Merging layer.
 --
 -- Similar to Concat layer, except sums the activations instead of creating a larger
 -- shape.
-data Merge :: * -> * -> * where
+data Merge :: Type -> Type -> Type where
   Merge :: x -> y -> Merge x y
 
 instance (Show x, Show y) => Show (Merge x y) where
diff --git a/src/Grenade/Layers/Pad.hs b/src/Grenade/Layers/Pad.hs
index 577a666..788153a 100644
--- a/src/Grenade/Layers/Pad.hs
+++ b/src/Grenade/Layers/Pad.hs
@@ -6,7 +6,6 @@
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE NoStarIsType #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-|
 Module : Grenade.Core.Pad
diff --git a/src/Grenade/Layers/Pooling.hs b/src/Grenade/Layers/Pooling.hs
index 58cf748..ea8e001 100644
--- a/src/Grenade/Layers/Pooling.hs
+++ b/src/Grenade/Layers/Pooling.hs
@@ -7,7 +7,6 @@
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE NoStarIsType #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-|
 Module : Grenade.Core.Pooling
diff --git a/src/Grenade/Layers/Reshape.hs b/src/Grenade/Layers/Reshape.hs
index e982763..8a657bc 100644
--- a/src/Grenade/Layers/Reshape.hs
+++ b/src/Grenade/Layers/Reshape.hs
@@ -3,7 +3,6 @@
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE NoStarIsType #-}
 {-# LANGUAGE UndecidableInstances #-}
 {-|
 Module : Grenade.Layers.Reshape
diff --git a/src/Grenade/Recurrent/Core/Layer.hs b/src/Grenade/Recurrent/Core/Layer.hs
index 491c203..2611cbc 100644
--- a/src/Grenade/Recurrent/Core/Layer.hs
+++ b/src/Grenade/Recurrent/Core/Layer.hs
@@ -8,6 +8,8 @@ module Grenade.Recurrent.Core.Layer (
   , RecurrentUpdateLayer (..)
   ) where
 
+import Data.Kind (Type)
+
 import Grenade.Core
 
 -- | Class for a recurrent layer.
@@ -15,11 +17,11 @@ import Grenade.Core
 -- of an extra recurrent data shape.
 class UpdateLayer x => RecurrentUpdateLayer x where
   -- | Shape of data that is passed between each subsequent run of the layer
-  type RecurrentShape x :: *
+  type RecurrentShape x :: Type
 
 class (RecurrentUpdateLayer x, Num (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
   -- | Wengert Tape
-  type RecTape x i o :: *
+  type RecTape x i o :: Type
   -- | Used in training and scoring. Take the input from the previous
   -- layer, and give the output from this layer.
   runRecurrentForwards :: x -> RecurrentShape x -> S i -> (RecTape x i o, RecurrentShape x, S o)
diff --git a/src/Grenade/Recurrent/Core/Network.hs b/src/Grenade/Recurrent/Core/Network.hs
index b1dcd66..8ad0c1b 100644
--- a/src/Grenade/Recurrent/Core/Network.hs
+++ b/src/Grenade/Recurrent/Core/Network.hs
@@ -32,18 +32,20 @@
 import Data.Singletons ( SingI )
 import Data.Singletons.Prelude ( Head, Last )
 import Data.Serialize
 
+import Data.Kind (Type)
+
 import Grenade.Core
 import Grenade.Recurrent.Core.Layer
 
 -- | Witness type to say indicate we're building up with a normal feed
 -- forward layer.
-data FeedForward :: * -> *
+data FeedForward :: Type -> Type
 -- | Witness type to say indicate we're building up with a recurrent layer.
-data Recurrent :: * -> *
+data Recurrent :: Type -> Type
 -- | Type of a recurrent neural network.
 --
--- The [*] type specifies the types of the layers.
+-- The [Type] type specifies the types of the layers.
 --
 -- The [Shape] type specifies the shapes of data passed between the layers.
 --
@@ -52,7 +54,7 @@ data Recurrent :: * -> *
 --
 -- Often, to make the definitions more concise, one will use a type alias
 -- for these empty data types.
-data RecurrentNetwork :: [*] -> [Shape] -> * where
+data RecurrentNetwork :: [Type] -> [Shape] -> Type where
   RNil :: SingI i
        => RecurrentNetwork '[] '[i]
 
@@ -71,7 +73,7 @@ infixr 5 :~@>
 -- | Gradient of a network.
 --
 -- Parameterised on the layers of the network.
-data RecurrentGradient :: [*] -> * where
+data RecurrentGradient :: [Type] -> Type where
   RGNil :: RecurrentGradient '[]
 
   (://>) :: UpdateLayer x
@@ -81,7 +83,7 @@ data RecurrentGradient :: [*] -> * where
 
 -- | Recurrent inputs (sideways shapes on an imaginary unrolled graph)
 -- Parameterised on the layers of a Network.
-data RecurrentInputs :: [*] -> * where
+data RecurrentInputs :: [Type] -> Type where
   RINil :: RecurrentInputs '[]
 
   (:~~+>) :: (UpdateLayer x, Fractional (RecurrentInputs xs))
@@ -95,7 +97,7 @@ data RecurrentInputs :: [*] -> * where
 --
 -- We index on the time step length as well, to ensure
 -- that that all Tape lengths are the same.
-data RecurrentTape :: [*] -> [Shape] -> * where
+data RecurrentTape :: [Type] -> [Shape] -> Type where
   TRNil :: SingI i
         => RecurrentTape '[] '[i]
 
@@ -204,7 +206,7 @@ instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recu
 
 -- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
 -- recurrent network and a set of random inputs for it is with the randomRecurrent.
-class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where
+class CreatableRecurrent (xs :: [Type]) (ss :: [Shape]) where
   -- | Create a network of the types requested
   randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss)
 
diff --git a/src/Grenade/Recurrent/Layers/BasicRecurrent.hs b/src/Grenade/Recurrent/Layers/BasicRecurrent.hs
index 4629ab1..3be165e 100644
--- a/src/Grenade/Recurrent/Layers/BasicRecurrent.hs
+++ b/src/Grenade/Recurrent/Layers/BasicRecurrent.hs
@@ -18,6 +18,8 @@
 import Control.Monad.Random ( MonadRandom, getRandom )
 
 import Data.Singletons.TypeLits
 
+import Data.Kind (Type)
+
 import Numeric.LinearAlgebra.Static
 import GHC.TypeLits
@@ -27,7 +29,7 @@ import Grenade.Recurrent.Core
 
 data BasicRecurrent :: Nat -- Input layer size
                     -> Nat -- Output layer size
-                    -> * where
+                    -> Type where
   BasicRecurrent :: ( KnownNat input
                     , KnownNat output
                     , KnownNat matrixCols
@@ -40,7 +42,7 @@ data BasicRecurrent :: Nat -- Input layer size
 
 data BasicRecurrent' :: Nat -- Input layer size
                      -> Nat -- Output layer size
-                     -> * where
+                     -> Type where
   BasicRecurrent' :: ( KnownNat input
                      , KnownNat output
                      , KnownNat matrixCols
diff --git a/src/Grenade/Recurrent/Layers/ConcatRecurrent.hs b/src/Grenade/Recurrent/Layers/ConcatRecurrent.hs
index c554eb9..71e34d4 100644
--- a/src/Grenade/Recurrent/Layers/ConcatRecurrent.hs
+++ b/src/Grenade/Recurrent/Layers/ConcatRecurrent.hs
@@ -27,6 +27,8 @@
 import Data.Serialize
 import Data.Singletons
 import GHC.TypeLits
 
+import Data.Kind (Type)
+
 import Grenade.Core
 import Grenade.Recurrent.Core
@@ -45,7 +47,7 @@ import Numeric.LinearAlgebra.Static ( (#), split, R )
 --
 -- 3D images become 3D images with more channels. The sizes must be the same, one can use Pad
 -- and Crop layers to ensure this is the case.
-data ConcatRecurrent :: Shape -> * -> Shape -> * -> * where
+data ConcatRecurrent :: Shape -> Type -> Shape -> Type -> Type where
   ConcatRecLeft  :: x -> y -> ConcatRecurrent m (Recurrent x) n (FeedForward y)
   ConcatRecRight :: x -> y -> ConcatRecurrent m (FeedForward x) n (Recurrent y)
   ConcatRecBoth  :: x -> y -> ConcatRecurrent m (Recurrent x) n (Recurrent y)
diff --git a/src/Grenade/Recurrent/Layers/LSTM.hs b/src/Grenade/Recurrent/Layers/LSTM.hs
index a3d6993..5f9b4bc 100644
--- a/src/Grenade/Recurrent/Layers/LSTM.hs
+++ b/src/Grenade/Recurrent/Layers/LSTM.hs
@@ -23,6 +23,8 @@
 import Data.Proxy
 import Data.Serialize
 import Data.Singletons.TypeLits
 
+import Data.Kind (Type)
+
 import qualified Numeric.LinearAlgebra as LA
 import Numeric.LinearAlgebra.Static
@@ -36,14 +38,14 @@ import Grenade.Layers.Internal.Update
 -- This is a Peephole formulation, so the recurrent shape is
 -- just the cell state, the previous output is not held or used
 -- at all.
-data LSTM :: Nat -> Nat -> * where
+data LSTM :: Nat -> Nat -> Type where
   LSTM :: ( KnownNat input
           , KnownNat output
           ) => !(LSTMWeights input output) -- Weights
            -> !(LSTMWeights input output) -- Momentums
            -> LSTM input output
 
-data LSTMWeights :: Nat -> Nat -> * where
+data LSTMWeights :: Nat -> Nat -> Type where
   LSTMWeights :: ( KnownNat input
                  , KnownNat output
                  ) => {
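
Note (not part of the patch): the snippet below is a minimal standalone sketch of what this migration is about; the module and type names (StarIsTypeSketch, Tensor, Flattened) are invented for illustration. On GHC 8.6+, the NoStarIsType extension rejects `*` in kind position, so the kind of ordinary types must be spelled Data.Kind.Type, and `*` is then free to denote GHC.TypeLits multiplication, which Grenade's shape-indexed layers use in their constraints. Older GHCs would reject the per-module `{-# LANGUAGE NoStarIsType #-}` pragmas as an unknown extension, which is presumably why the patch moves the extension into grenade.cabal behind `impl(ghc >= 8.6)` and spells out `Type` everywhere: that spelling is accepted both with and without StarIsType.

{-# LANGUAGE DataKinds      #-}
{-# LANGUAGE GADTs          #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType   #-}
{-# LANGUAGE TypeFamilies   #-}
{-# LANGUAGE TypeOperators  #-}

module StarIsTypeSketch where

import Data.Kind (Type)
import GHC.TypeLits (Nat, type (*))

-- With NoStarIsType, the kind of ordinary types is written `Type`
-- (from Data.Kind), exactly as the patch does throughout:
data Tensor :: Nat -> Nat -> Type where
  Tensor :: Tensor rows cols

-- ...and a bare `*` now unambiguously means type-level
-- multiplication, the kind of arithmetic that appears in
-- constraints such as KnownNat (rows * cols):
type family Flattened (r :: Nat) (c :: Nat) :: Nat where
  Flattened r c = r * c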