mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
fix test suite for GHCs 8.2 8.4 8.6
This commit is contained in:
parent
6c39f41e1d
commit
84f54dbe9b
@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
@ -15,6 +16,10 @@ import Data.Singletons ()
|
||||
import GHC.TypeLits
|
||||
import GHC.TypeLits.Witnesses
|
||||
|
||||
#if MIN_VERSION_base(4,9,0)
|
||||
import Data.Kind (Type)
|
||||
#endif
|
||||
|
||||
import Grenade.Core
|
||||
import Grenade.Layers.Convolution
|
||||
|
||||
@ -25,7 +30,7 @@ import Test.Hedgehog.Hmatrix
|
||||
import Test.Hedgehog.TypeLits
|
||||
import Test.Hedgehog.Compat
|
||||
|
||||
data OpaqueConvolution :: * where
|
||||
data OpaqueConvolution :: Type where
|
||||
OpaqueConvolution :: Convolution channels filters kernelRows kernelColumns strideRows strideColumns -> OpaqueConvolution
|
||||
|
||||
instance Show OpaqueConvolution where
|
||||
|
@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
@ -12,6 +13,10 @@ import Data.Singletons ()
|
||||
|
||||
import GHC.TypeLits
|
||||
|
||||
#if MIN_VERSION_base(4,9,0)
|
||||
import Data.Kind (Type)
|
||||
#endif
|
||||
|
||||
import Grenade.Core
|
||||
import Grenade.Layers.FullyConnected
|
||||
|
||||
@ -20,7 +25,7 @@ import Hedgehog
|
||||
import Test.Hedgehog.Compat
|
||||
import Test.Hedgehog.Hmatrix
|
||||
|
||||
data OpaqueFullyConnected :: * where
|
||||
data OpaqueFullyConnected :: Type where
|
||||
OpaqueFullyConnected :: (KnownNat i, KnownNat o) => FullyConnected i o -> OpaqueFullyConnected
|
||||
|
||||
instance Show OpaqueFullyConnected where
|
||||
|
@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
@ -9,6 +10,10 @@ module Test.Grenade.Layers.Pooling where
|
||||
import Data.Proxy
|
||||
import Data.Singletons ()
|
||||
|
||||
#if MIN_VERSION_base(4,9,0)
|
||||
import Data.Kind (Type)
|
||||
#endif
|
||||
|
||||
import GHC.TypeLits
|
||||
import Grenade.Layers.Pooling
|
||||
|
||||
@ -16,7 +21,7 @@ import Hedgehog
|
||||
|
||||
import Test.Hedgehog.Compat
|
||||
|
||||
data OpaquePooling :: * where
|
||||
data OpaquePooling :: Type where
|
||||
OpaquePooling :: (KnownNat kh, KnownNat kw, KnownNat sh, KnownNat sw) => Pooling kh kw sh sw -> OpaquePooling
|
||||
|
||||
instance Show OpaquePooling where
|
||||
@ -24,10 +29,10 @@ instance Show OpaquePooling where
|
||||
|
||||
genOpaquePooling :: Gen OpaquePooling
|
||||
genOpaquePooling = do
|
||||
Just kernelHeight <- someNatVal <$> choose 2 15
|
||||
Just kernelWidth <- someNatVal <$> choose 2 15
|
||||
Just strideHeight <- someNatVal <$> choose 2 15
|
||||
Just strideWidth <- someNatVal <$> choose 2 15
|
||||
~(Just kernelHeight) <- someNatVal <$> choose 2 15
|
||||
~(Just kernelWidth ) <- someNatVal <$> choose 2 15
|
||||
~(Just strideHeight) <- someNatVal <$> choose 2 15
|
||||
~(Just strideWidth ) <- someNatVal <$> choose 2 15
|
||||
|
||||
case (kernelHeight, kernelWidth, strideHeight, strideWidth) of
|
||||
(SomeNat (_ :: Proxy kh), SomeNat (_ :: Proxy kw), SomeNat (_ :: Proxy sh), SomeNat (_ :: Proxy sw)) ->
|
||||
|
@ -35,6 +35,9 @@ import GHC.TypeLits hiding (natVal)
|
||||
#else
|
||||
import GHC.TypeLits
|
||||
#endif
|
||||
#if MIN_VERSION_base(4,9,0)
|
||||
import Data.Kind (Type)
|
||||
#endif
|
||||
|
||||
import GHC.TypeLits.Witnesses
|
||||
import Test.Hedgehog.Compat
|
||||
@ -46,7 +49,7 @@ import Numeric.LinearAlgebra ( flatten )
|
||||
import Numeric.LinearAlgebra.Static ( extract, norm_Inf )
|
||||
import Unsafe.Coerce
|
||||
|
||||
data SomeNetwork :: * where
|
||||
data SomeNetwork :: Type where
|
||||
SomeNetwork :: ( SingI shapes, SingI (Head shapes), SingI (Last shapes), Show (Network layers shapes) ) => Network layers shapes -> SomeNetwork
|
||||
|
||||
instance Show SomeNetwork where
|
||||
@ -448,7 +451,7 @@ oneUp =
|
||||
D1Sing SNat ->
|
||||
let x = 0 :: S ( shape )
|
||||
in case x of
|
||||
( S1D x' ) -> do
|
||||
( S1D x' ) -> do
|
||||
let ex = extract x'
|
||||
let len = VS.length ex
|
||||
ix <- choose 0 (len - 1)
|
||||
@ -460,7 +463,7 @@ oneUp =
|
||||
D2Sing SNat SNat ->
|
||||
let x = 0 :: S ( shape )
|
||||
in case x of
|
||||
( S2D x' ) -> do
|
||||
( S2D x' ) -> do
|
||||
let ex = flatten ( extract x' )
|
||||
let len = VS.length ex
|
||||
ix <- choose 0 (len - 1)
|
||||
@ -472,7 +475,7 @@ oneUp =
|
||||
D3Sing SNat SNat SNat ->
|
||||
let x = 0 :: S ( shape )
|
||||
in case x of
|
||||
( S3D x' ) -> do
|
||||
( S3D x' ) -> do
|
||||
let ex = flatten ( extract x' )
|
||||
let len = VS.length ex
|
||||
ix <- choose 0 (len - 1)
|
||||
|
@ -24,7 +24,7 @@ import Test.Hedgehog.Compat
|
||||
|
||||
genNat :: Gen SomeNat
|
||||
genNat = do
|
||||
Just n <- someNatVal <$> choose 1 10
|
||||
~(Just n) <- someNatVal <$> choose 1 10
|
||||
return n
|
||||
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
|
Loading…
Reference in New Issue
Block a user