diff --git a/src/Grenade/Core/Shape.hs b/src/Grenade/Core/Shape.hs index e877515..4f3177d 100644 --- a/src/Grenade/Core/Shape.hs +++ b/src/Grenade/Core/Shape.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} @@ -35,7 +36,11 @@ import Data.Singletons.TypeLits import Data.Vector.Storable ( Vector ) import qualified Data.Vector.Storable as V +#if MIN_VERSION_base(4,11,0) +import GHC.TypeLits hiding (natVal) +#else import GHC.TypeLits +#endif import qualified Numeric.LinearAlgebra.Static as H import Numeric.LinearAlgebra.Static diff --git a/src/Grenade/Layers/Convolution.hs b/src/Grenade/Layers/Convolution.hs index 1ee2df0..ad5c0f4 100644 --- a/src/Grenade/Layers/Convolution.hs +++ b/src/Grenade/Layers/Convolution.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE RecordWildCards #-} @@ -29,7 +30,11 @@ import Data.Proxy import Data.Serialize import Data.Singletons.TypeLits +#if MIN_VERSION_base(4,11,0) +import GHC.TypeLits hiding (natVal) +#else import GHC.TypeLits +#endif import Numeric.LinearAlgebra hiding ( uniformSample, konst ) import qualified Numeric.LinearAlgebra as LA diff --git a/src/Grenade/Layers/Crop.hs b/src/Grenade/Layers/Crop.hs index 11d0c73..8cd6df0 100644 --- a/src/Grenade/Layers/Crop.hs +++ b/src/Grenade/Layers/Crop.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE GADTs #-} @@ -20,7 +21,12 @@ import Data.Maybe import Data.Proxy import Data.Serialize import Data.Singletons.TypeLits + +#if MIN_VERSION_base(4,11,0) +import GHC.TypeLits hiding (natVal) +#else import GHC.TypeLits +#endif import Grenade.Core import Grenade.Layers.Internal.Pad diff --git a/src/Grenade/Layers/Deconvolution.hs b/src/Grenade/Layers/Deconvolution.hs index f364388..a76517a 100644 --- a/src/Grenade/Layers/Deconvolution.hs +++ b/src/Grenade/Layers/Deconvolution.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE RecordWildCards #-} @@ -33,7 +34,11 @@ import Data.Proxy import Data.Serialize import Data.Singletons.TypeLits +#if MIN_VERSION_base(4,11,0) +import GHC.TypeLits hiding (natVal) +#else import GHC.TypeLits +#endif import Numeric.LinearAlgebra hiding ( uniformSample, konst ) import qualified Numeric.LinearAlgebra as LA diff --git a/src/Grenade/Layers/Pad.hs b/src/Grenade/Layers/Pad.hs index e9600e1..eec06a1 100644 --- a/src/Grenade/Layers/Pad.hs +++ b/src/Grenade/Layers/Pad.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE GADTs #-} @@ -20,7 +21,12 @@ import Data.Maybe import Data.Proxy import Data.Serialize import Data.Singletons.TypeLits + +#if MIN_VERSION_base(4,11,0) +import GHC.TypeLits hiding (natVal) +#else import GHC.TypeLits +#endif import Grenade.Core import Grenade.Layers.Internal.Pad diff --git a/src/Grenade/Layers/Pooling.hs b/src/Grenade/Layers/Pooling.hs index b33d188..6161d43 100644 --- a/src/Grenade/Layers/Pooling.hs +++ b/src/Grenade/Layers/Pooling.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -21,7 +22,12 @@ import Data.Maybe import Data.Proxy import Data.Serialize import Data.Singletons.TypeLits + +#if MIN_VERSION_base(4,11,0) +import GHC.TypeLits hiding (natVal) +#else import GHC.TypeLits +#endif import Grenade.Core import Grenade.Layers.Internal.Pooling diff --git a/test/Test/Grenade/Network.hs b/test/Test/Grenade/Network.hs index 0ee0318..83afed9 100644 --- a/test/Test/Grenade/Network.hs +++ b/test/Test/Grenade/Network.hs @@ -23,8 +23,6 @@ import Data.Singletons import Data.Singletons.Prelude.List import Data.Singletons.TypeLits --- import Data.Type.Equality - import Hedgehog import qualified Hedgehog.Gen as Gen import Hedgehog.Internal.Source @@ -32,7 +30,12 @@ import Hedgehog.Internal.Property ( failWith ) import Grenade +#if MIN_VERSION_base(4,11,0) +import GHC.TypeLits hiding (natVal) +#else import GHC.TypeLits +#endif + import GHC.TypeLits.Witnesses import Test.Hedgehog.Compat import Test.Hedgehog.TypeLits @@ -74,8 +77,9 @@ genNetwork = , pure (SomeNetwork (Elu :~> rest :: Network ( Elu ': layers ) ( h ': h ': hs ))) , pure (SomeNetwork (Softmax :~> rest :: Network ( Softmax ': layers ) ( h ': h ': hs ))) , do -- Reshape to two dimensions - let divisors n = 1 : [x | x <- [2..(n-1)], n `rem` x == 0] - let len = natVal l + let divisors :: Integer -> [Integer] + divisors n = 1 : [x | x <- [2..(n-1)], n `rem` x == 0] + let len = fromIntegral $ natVal l rs <- Gen.element $ divisors len let cs = len `quot` rs case ( someNatVal rs, someNatVal cs, someNatVal len ) of @@ -96,8 +100,8 @@ genNetwork = , do -- Build a convolution layer with one filter output -- Figure out some kernel sizes which work for this layer -- There must be a better way than this... - let output_r = natVal r - let output_c = natVal c + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c let ok extent kernel = [stride | stride <- [ 1 .. extent ], (extent - kernel) `mod` stride == 0] @@ -136,8 +140,8 @@ genNetwork = , do -- Build a convolution layer with one filter output -- Figure out some kernel sizes which work for this layer -- There must be a better way than this... - let output_r = natVal r - let output_c = natVal c + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c let ok extent kernel = [stride | stride <- [ 1 .. extent ], (extent - kernel) `mod` stride == 0] @@ -178,8 +182,8 @@ genNetwork = pure (SomeNetwork (conv :~> rest :: Network ( Convolution channels 1 kernelRows kernelCols strideRows strideCols ': layers ) ( ('D3 inRows inCols channels) ': h ': hs ))) _ -> Gen.discard -- Can't occur , do -- Build a Pooling layer - let output_r = natVal r - let output_c = natVal c + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c let ok extent kernel = [stride | stride <- [ 1 .. extent ], (extent - kernel) `mod` stride == 0] @@ -215,8 +219,8 @@ genNetwork = pure (SomeNetwork (Pooling :~> rest :: Network ( Pooling kernelRows kernelCols strideRows strideCols ': layers ) ( ('D2 inRows inCols) ': h ': hs ))) _ -> Gen.discard -- Can't occur , do -- Build a Pad layer - let output_r = natVal r - let output_c = natVal c + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c pad_left <- choose 0 (output_r - 1) pad_right <- choose 0 (output_r - 1 - pad_left) @@ -242,8 +246,8 @@ genNetwork = pure (SomeNetwork (Pad :~> rest :: Network ( Pad padLeft padTop padRight padBottom ': layers ) ( ('D2 inputRows inputColumns) ': h ': hs ))) _ -> Gen.discard -- Can't occur , do -- Build a Crop layer - let output_r = natVal r - let output_c = natVal c + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c crop_left <- choose 0 10 crop_right <- choose 0 10 @@ -275,9 +279,9 @@ genNetwork = , do -- Build a convolution layer with one filter output -- Figure out some kernel sizes which work for this layer -- There must be a better way than this... - let output_r = natVal r - let output_c = natVal c - let output_f = natVal f + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c + let output_f = fromIntegral $ natVal f let ok extent kernel = [stride | stride <- [ 1 .. extent ], (extent - kernel) `mod` stride == 0] @@ -318,9 +322,9 @@ genNetwork = pure (SomeNetwork (conv :~> rest :: Network ( Convolution channels filters kernelRows kernelCols strideRows strideCols ': layers ) ( ('D3 inRows inCols channels) ': h ': hs ))) _ -> Gen.discard -- Can't occur , do -- Build a Pooling layer - let output_r = natVal r - let output_c = natVal c - let output_f = natVal f + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c + let output_f = fromIntegral $ natVal f let ok extent kernel = [stride | stride <- [ 1 .. extent ], (extent - kernel) `mod` stride == 0] @@ -359,9 +363,9 @@ genNetwork = pure (SomeNetwork (Pooling :~> rest :: Network ( Pooling kernelRows kernelCols strideRows strideCols ': layers ) ( ('D3 inRows inCols filters) ': h ': hs ))) _ -> Gen.discard -- Can't occur , do -- Build a Pad layer - let output_r = natVal r - let output_c = natVal c - let output_f = natVal f + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c + let output_f = fromIntegral $ natVal f pad_left <- choose 0 (output_r - 1) pad_right <- choose 0 (output_r - 1 - pad_left) @@ -389,9 +393,9 @@ genNetwork = pure (SomeNetwork (Pad :~> rest :: Network ( Pad padLeft padTop padRight padBottom ': layers ) ( ('D3 inputRows inputColumns filters) ': h ': hs ))) _ -> Gen.discard -- Can't occur , do -- Build a Crop layer - let output_r = natVal r - let output_c = natVal c - let output_f = natVal f + let output_r = fromIntegral $ natVal r + let output_c = fromIntegral $ natVal c + let output_f = fromIntegral $ natVal f crop_left <- choose 0 10 crop_right <- choose 0 10