mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-21 21:59:30 +03:00
GHC 8.4 compatibility
- `-XNoStarIsType` is new in GHC 8.6 - can't use `-XCPP` to conditionally enable `LANGUAGE` pragmas (implicit module declaration) - use a conditional in the .cabal file, but this is active for all modules - need to replace all `*` with `Type` and `import Data.Kind (Type)` (works in 8.4)
This commit is contained in:
parent
0bc6bd5fad
commit
49a4280a91
@ -58,6 +58,8 @@ library
|
||||
if impl(ghc < 8.0)
|
||||
ghc-options: -fno-warn-incomplete-patterns
|
||||
|
||||
if impl(ghc >= 8.6)
|
||||
default-extensions: NoStarIsType
|
||||
|
||||
exposed-modules:
|
||||
Grenade
|
||||
|
@ -41,6 +41,8 @@ import Control.Monad.Random ( MonadRandom )
|
||||
|
||||
import Data.List ( foldl' )
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
import Grenade.Core.Shape
|
||||
import Grenade.Core.LearningParameters
|
||||
|
||||
@ -50,7 +52,7 @@ import Grenade.Core.LearningParameters
|
||||
class UpdateLayer x where
|
||||
-- | The type for the gradient for this layer.
|
||||
-- Unit if there isn't a gradient to pass back.
|
||||
type Gradient x :: *
|
||||
type Gradient x :: Type
|
||||
|
||||
-- | Update a layer with its gradient and learning parameters
|
||||
runUpdate :: LearningParameters -> x -> Gradient x -> x
|
||||
@ -72,7 +74,7 @@ class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
|
||||
-- | The Wengert tape for this layer. Includes all that is required
|
||||
-- to generate the back propagated gradients efficiently. As a
|
||||
-- default, `S i` is fine.
|
||||
type Tape x i o :: *
|
||||
type Tape x i o :: Type
|
||||
|
||||
-- | Used in training and scoring. Take the input from the previous
|
||||
-- layer, and give the output from this layer.
|
||||
|
@ -36,6 +36,8 @@ import Data.Singletons
|
||||
import Data.Singletons.Prelude
|
||||
import Data.Serialize
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
import Grenade.Core.Layer
|
||||
import Grenade.Core.LearningParameters
|
||||
import Grenade.Core.Shape
|
||||
@ -48,7 +50,7 @@ import Grenade.Core.Shape
|
||||
--
|
||||
-- Can be considered to be a heterogeneous list of layers which are able to
|
||||
-- transform the data shapes of the network.
|
||||
data Network :: [*] -> [Shape] -> * where
|
||||
data Network :: [Type] -> [Shape] -> Type where
|
||||
NNil :: SingI i
|
||||
=> Network '[] '[i]
|
||||
|
||||
@ -66,7 +68,7 @@ instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) wh
|
||||
-- | Gradient of a network.
|
||||
--
|
||||
-- Parameterised on the layers of the network.
|
||||
data Gradients :: [*] -> * where
|
||||
data Gradients :: [Type] -> Type where
|
||||
GNil :: Gradients '[]
|
||||
|
||||
(:/>) :: UpdateLayer x
|
||||
@ -77,7 +79,7 @@ data Gradients :: [*] -> * where
|
||||
-- | Wegnert Tape of a network.
|
||||
--
|
||||
-- Parameterised on the layers and shapes of the network.
|
||||
data Tapes :: [*] -> [Shape] -> * where
|
||||
data Tapes :: [Type] -> [Shape] -> Type where
|
||||
TNil :: SingI i
|
||||
=> Tapes '[] '[i]
|
||||
|
||||
@ -152,7 +154,7 @@ applyUpdate _ NNil GNil
|
||||
|
||||
-- | A network can easily be created by hand with (:~>), but an easy way to
|
||||
-- initialise a random network is with the randomNetwork.
|
||||
class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
|
||||
class CreatableNetwork (xs :: [Type]) (ss :: [Shape]) where
|
||||
-- | Create a network with randomly initialised weights.
|
||||
--
|
||||
-- Calls to this function will not compile if the type of the neural
|
||||
|
@ -8,7 +8,6 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE NoStarIsType #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-|
|
||||
Module : Grenade.Core.Shape
|
||||
|
@ -8,7 +8,6 @@
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE NoStarIsType #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-|
|
||||
Module : Grenade.Layers.Concat
|
||||
|
@ -8,7 +8,6 @@
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NoStarIsType #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-|
|
||||
Module : Grenade.Layers.Convolution
|
||||
|
@ -6,7 +6,6 @@
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE NoStarIsType #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-|
|
||||
Module : Grenade.Layers.Crop
|
||||
|
@ -8,7 +8,6 @@
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NoStarIsType #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-|
|
||||
Module : Grenade.Layers.Deconvolution
|
||||
|
@ -23,13 +23,15 @@ import Data.Serialize
|
||||
|
||||
import Data.Singletons
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
import Grenade.Core
|
||||
|
||||
-- | A Merging layer.
|
||||
--
|
||||
-- Similar to Concat layer, except sums the activations instead of creating a larger
|
||||
-- shape.
|
||||
data Merge :: * -> * -> * where
|
||||
data Merge :: Type -> Type -> Type where
|
||||
Merge :: x -> y -> Merge x y
|
||||
|
||||
instance (Show x, Show y) => Show (Merge x y) where
|
||||
|
@ -6,7 +6,6 @@
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NoStarIsType #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-|
|
||||
Module : Grenade.Core.Pad
|
||||
|
@ -7,7 +7,6 @@
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NoStarIsType #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-|
|
||||
Module : Grenade.Core.Pooling
|
||||
|
@ -3,7 +3,6 @@
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE NoStarIsType #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
{-|
|
||||
Module : Grenade.Layers.Reshape
|
||||
|
@ -8,6 +8,8 @@ module Grenade.Recurrent.Core.Layer (
|
||||
, RecurrentUpdateLayer (..)
|
||||
) where
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
import Grenade.Core
|
||||
|
||||
-- | Class for a recurrent layer.
|
||||
@ -15,11 +17,11 @@ import Grenade.Core
|
||||
-- of an extra recurrent data shape.
|
||||
class UpdateLayer x => RecurrentUpdateLayer x where
|
||||
-- | Shape of data that is passed between each subsequent run of the layer
|
||||
type RecurrentShape x :: *
|
||||
type RecurrentShape x :: Type
|
||||
|
||||
class (RecurrentUpdateLayer x, Num (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
|
||||
-- | Wengert Tape
|
||||
type RecTape x i o :: *
|
||||
type RecTape x i o :: Type
|
||||
-- | Used in training and scoring. Take the input from the previous
|
||||
-- layer, and give the output from this layer.
|
||||
runRecurrentForwards :: x -> RecurrentShape x -> S i -> (RecTape x i o, RecurrentShape x, S o)
|
||||
|
@ -32,18 +32,20 @@ import Data.Singletons ( SingI )
|
||||
import Data.Singletons.Prelude ( Head, Last )
|
||||
import Data.Serialize
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
import Grenade.Core
|
||||
import Grenade.Recurrent.Core.Layer
|
||||
|
||||
-- | Witness type to say indicate we're building up with a normal feed
|
||||
-- forward layer.
|
||||
data FeedForward :: * -> *
|
||||
data FeedForward :: Type -> Type
|
||||
-- | Witness type to say indicate we're building up with a recurrent layer.
|
||||
data Recurrent :: * -> *
|
||||
data Recurrent :: Type -> Type
|
||||
|
||||
-- | Type of a recurrent neural network.
|
||||
--
|
||||
-- The [*] type specifies the types of the layers.
|
||||
-- The [Type] type specifies the types of the layers.
|
||||
--
|
||||
-- The [Shape] type specifies the shapes of data passed between the layers.
|
||||
--
|
||||
@ -52,7 +54,7 @@ data Recurrent :: * -> *
|
||||
--
|
||||
-- Often, to make the definitions more concise, one will use a type alias
|
||||
-- for these empty data types.
|
||||
data RecurrentNetwork :: [*] -> [Shape] -> * where
|
||||
data RecurrentNetwork :: [Type] -> [Shape] -> Type where
|
||||
RNil :: SingI i
|
||||
=> RecurrentNetwork '[] '[i]
|
||||
|
||||
@ -71,7 +73,7 @@ infixr 5 :~@>
|
||||
-- | Gradient of a network.
|
||||
--
|
||||
-- Parameterised on the layers of the network.
|
||||
data RecurrentGradient :: [*] -> * where
|
||||
data RecurrentGradient :: [Type] -> Type where
|
||||
RGNil :: RecurrentGradient '[]
|
||||
|
||||
(://>) :: UpdateLayer x
|
||||
@ -81,7 +83,7 @@ data RecurrentGradient :: [*] -> * where
|
||||
|
||||
-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph)
|
||||
-- Parameterised on the layers of a Network.
|
||||
data RecurrentInputs :: [*] -> * where
|
||||
data RecurrentInputs :: [Type] -> Type where
|
||||
RINil :: RecurrentInputs '[]
|
||||
|
||||
(:~~+>) :: (UpdateLayer x, Fractional (RecurrentInputs xs))
|
||||
@ -95,7 +97,7 @@ data RecurrentInputs :: [*] -> * where
|
||||
--
|
||||
-- We index on the time step length as well, to ensure
|
||||
-- that that all Tape lengths are the same.
|
||||
data RecurrentTape :: [*] -> [Shape] -> * where
|
||||
data RecurrentTape :: [Type] -> [Shape] -> Type where
|
||||
TRNil :: SingI i
|
||||
=> RecurrentTape '[] '[i]
|
||||
|
||||
@ -204,7 +206,7 @@ instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recu
|
||||
|
||||
-- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
|
||||
-- recurrent network and a set of random inputs for it is with the randomRecurrent.
|
||||
class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where
|
||||
class CreatableRecurrent (xs :: [Type]) (ss :: [Shape]) where
|
||||
-- | Create a network of the types requested
|
||||
randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss)
|
||||
|
||||
|
@ -18,6 +18,8 @@ import Control.Monad.Random ( MonadRandom, getRandom )
|
||||
|
||||
import Data.Singletons.TypeLits
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
import Numeric.LinearAlgebra.Static
|
||||
|
||||
import GHC.TypeLits
|
||||
@ -27,7 +29,7 @@ import Grenade.Recurrent.Core
|
||||
|
||||
data BasicRecurrent :: Nat -- Input layer size
|
||||
-> Nat -- Output layer size
|
||||
-> * where
|
||||
-> Type where
|
||||
BasicRecurrent :: ( KnownNat input
|
||||
, KnownNat output
|
||||
, KnownNat matrixCols
|
||||
@ -40,7 +42,7 @@ data BasicRecurrent :: Nat -- Input layer size
|
||||
|
||||
data BasicRecurrent' :: Nat -- Input layer size
|
||||
-> Nat -- Output layer size
|
||||
-> * where
|
||||
-> Type where
|
||||
BasicRecurrent' :: ( KnownNat input
|
||||
, KnownNat output
|
||||
, KnownNat matrixCols
|
||||
|
@ -27,6 +27,8 @@ import Data.Serialize
|
||||
import Data.Singletons
|
||||
import GHC.TypeLits
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
import Grenade.Core
|
||||
import Grenade.Recurrent.Core
|
||||
|
||||
@ -45,7 +47,7 @@ import Numeric.LinearAlgebra.Static ( (#), split, R )
|
||||
--
|
||||
-- 3D images become 3D images with more channels. The sizes must be the same, one can use Pad
|
||||
-- and Crop layers to ensure this is the case.
|
||||
data ConcatRecurrent :: Shape -> * -> Shape -> * -> * where
|
||||
data ConcatRecurrent :: Shape -> Type -> Shape -> Type -> Type where
|
||||
ConcatRecLeft :: x -> y -> ConcatRecurrent m (Recurrent x) n (FeedForward y)
|
||||
ConcatRecRight :: x -> y -> ConcatRecurrent m (FeedForward x) n (Recurrent y)
|
||||
ConcatRecBoth :: x -> y -> ConcatRecurrent m (Recurrent x) n (Recurrent y)
|
||||
|
@ -23,6 +23,8 @@ import Data.Proxy
|
||||
import Data.Serialize
|
||||
import Data.Singletons.TypeLits
|
||||
|
||||
import Data.Kind (Type)
|
||||
|
||||
import qualified Numeric.LinearAlgebra as LA
|
||||
import Numeric.LinearAlgebra.Static
|
||||
|
||||
@ -36,14 +38,14 @@ import Grenade.Layers.Internal.Update
|
||||
-- This is a Peephole formulation, so the recurrent shape is
|
||||
-- just the cell state, the previous output is not held or used
|
||||
-- at all.
|
||||
data LSTM :: Nat -> Nat -> * where
|
||||
data LSTM :: Nat -> Nat -> Type where
|
||||
LSTM :: ( KnownNat input
|
||||
, KnownNat output
|
||||
) => !(LSTMWeights input output) -- Weights
|
||||
-> !(LSTMWeights input output) -- Momentums
|
||||
-> LSTM input output
|
||||
|
||||
data LSTMWeights :: Nat -> Nat -> * where
|
||||
data LSTMWeights :: Nat -> Nat -> Type where
|
||||
LSTMWeights :: ( KnownNat input
|
||||
, KnownNat output
|
||||
) => {
|
||||
|
Loading…
Reference in New Issue
Block a user