diff --git a/examples/grenade-examples.cabal b/examples/grenade-examples.cabal index 9db0f73..c7c493d 100644 --- a/examples/grenade-examples.cabal +++ b/examples/grenade-examples.cabal @@ -26,7 +26,7 @@ executable feedforward , bytestring , cereal , either - , optparse-applicative >= 0.13 && < 0.17 + , optparse-applicative >= 0.13 && < 0.18 , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix @@ -42,7 +42,7 @@ executable mnist , grenade , attoparsec , either - , optparse-applicative >= 0.13 && < 0.17 + , optparse-applicative >= 0.13 && < 0.18 , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix >= 0.18 && < 0.21 @@ -61,7 +61,7 @@ executable gan-mnist , bytestring , cereal , either - , optparse-applicative >= 0.13 && < 0.17 + , optparse-applicative >= 0.13 && < 0.18 , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix >= 0.18 && < 0.21 @@ -78,7 +78,7 @@ executable recurrent , grenade , attoparsec , either - , optparse-applicative >= 0.13 && < 0.17 + , optparse-applicative >= 0.13 && < 0.18 , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix >= 0.18 && < 0.21 @@ -96,13 +96,14 @@ executable shakespeare , bytestring , cereal , either - , optparse-applicative >= 0.13 && < 0.17 + , optparse-applicative >= 0.13 && < 0.18 , text == 1.2.* , mtl >= 2.2.1 && < 2.3 , hmatrix >= 0.18 && < 0.21 , transformers , semigroups , singletons + , singletons-base , vector , MonadRandom , containers diff --git a/examples/main/shakespeare.hs b/examples/main/shakespeare.hs index bd9ff97..88571ea 100644 --- a/examples/main/shakespeare.hs +++ b/examples/main/shakespeare.hs @@ -14,22 +14,15 @@ import Data.Char ( isUpper, toUpper, toLower ) import Data.List ( foldl' ) import Data.Maybe ( fromMaybe ) -#if ! MIN_VERSION_base(4,13,0) -import Data.Semigroup ( (<>) ) -#endif import qualified Data.Vector as V import Data.Vector ( Vector ) import qualified Data.Map as M -#if ! MIN_VERSION_base(4,13,0) -import Data.Proxy ( Proxy (..) ) -#endif import qualified Data.ByteString as B import Data.Serialize -import Data.Singletons.Prelude import GHC.TypeLits import Numeric.LinearAlgebra.Static ( konst ) @@ -41,6 +34,8 @@ import Grenade.Recurrent import Grenade.Utils.OneHot import System.IO.Unsafe ( unsafeInterleaveIO ) +import Data.Proxy +import Prelude.Singletons -- The defininition for our natural language recurrent network. -- This network is able to learn and generate simple words in diff --git a/grenade.cabal b/grenade.cabal index 474bb92..9604748 100644 --- a/grenade.cabal +++ b/grenade.cabal @@ -38,7 +38,7 @@ source-repository head library build-depends: base >= 4.8 && < 5 - , bytestring == 0.10.* + , bytestring >= 0.10.0 , containers >= 0.5 && < 0.7 , cereal >= 0.5 && < 0.6 , deepseq >= 1.4 && < 1.5 @@ -48,6 +48,7 @@ library -- Versions of singletons are *tightly* coupled with the -- GHC version so its fine to drop version bounds. , singletons + , singletons-base , vector >= 0.11 && < 0.13 ghc-options: diff --git a/src/Grenade/Core/Network.hs b/src/Grenade/Core/Network.hs index f7b1f00..f1f8687 100644 --- a/src/Grenade/Core/Network.hs +++ b/src/Grenade/Core/Network.hs @@ -34,7 +34,6 @@ module Grenade.Core.Network ( import Control.Monad.Random ( MonadRandom ) import Data.Singletons -import Data.Singletons.Prelude import Data.Serialize #if MIN_VERSION_base(4,9,0) @@ -44,6 +43,7 @@ import Data.Kind (Type) import Grenade.Core.Layer import Grenade.Core.LearningParameters import Grenade.Core.Shape +import Prelude.Singletons -- | Type of a network. -- diff --git a/src/Grenade/Core/Runner.hs b/src/Grenade/Core/Runner.hs index 46149a1..4165bdf 100644 --- a/src/Grenade/Core/Runner.hs +++ b/src/Grenade/Core/Runner.hs @@ -13,11 +13,12 @@ module Grenade.Core.Runner ( , runNet ) where -import Data.Singletons.Prelude import Grenade.Core.LearningParameters import Grenade.Core.Network import Grenade.Core.Shape +import Data.Singletons +import Prelude.Singletons -- | Perform reverse automatic differentiation on the network -- for the current input and expected output. diff --git a/src/Grenade/Core/Shape.hs b/src/Grenade/Core/Shape.hs index 6a0d78c..c46d9bc 100644 --- a/src/Grenade/Core/Shape.hs +++ b/src/Grenade/Core/Shape.hs @@ -21,35 +21,20 @@ Stability : experimental module Grenade.Core.Shape ( S (..) , Shape (..) -#if MIN_VERSION_singletons(2,6,0) , SShape (..) -#else - , Sing (..) -#endif - , randomOfShape , fromStorable ) where import Control.DeepSeq (NFData (..)) import Control.Monad.Random ( MonadRandom, getRandom ) - -#if MIN_VERSION_base(4,13,0) import Data.Kind (Type) -#endif import Data.Proxy import Data.Serialize import Data.Singletons -import Data.Singletons.TypeLits import Data.Vector.Storable ( Vector ) import qualified Data.Vector.Storable as V - -#if MIN_VERSION_base(4,11,0) -import GHC.TypeLits hiding (natVal) -#else import GHC.TypeLits -#endif - import qualified Numeric.LinearAlgebra.Static as H import Numeric.LinearAlgebra.Static import qualified Numeric.LinearAlgebra as NLA @@ -99,9 +84,9 @@ deriving instance Show (S n) type instance Sing = SShape data SShape :: Shape -> Type where - D1Sing :: Sing a -> SShape ('D1 a) - D2Sing :: Sing a -> Sing b -> SShape ('D2 a b) - D3Sing :: KnownNat (a * c) => Sing a -> Sing b -> Sing c -> SShape ('D3 a b c) + D1Sing :: KnownNat a => SShape ('D1 a) + D2Sing :: (KnownNat a, KnownNat b) => SShape ('D2 a b) + D3Sing :: (KnownNat (a * c), KnownNat a, KnownNat b, KnownNat c) => SShape ('D3 a b c) #else data instance Sing (n :: Shape) where D1Sing :: Sing a -> Sing ('D1 a) @@ -110,11 +95,11 @@ data instance Sing (n :: Shape) where #endif instance KnownNat a => SingI ('D1 a) where - sing = D1Sing sing + sing = D1Sing instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where - sing = D2Sing sing sing + sing = D2Sing instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where - sing = D3Sing sing sing sing + sing = D3Sing instance SingI x => Num (S x) where (+) = n2 (+) @@ -163,13 +148,13 @@ randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x) randomOfShape = do seed :: Int <- getRandom return $ case (sing :: Sing x) of - D1Sing SNat -> + D1Sing -> S1D (randomVector seed Uniform * 2 - 1) - D2Sing SNat SNat -> + D2Sing -> S2D (uniformSample seed (-1) 1) - D3Sing SNat SNat SNat -> + D3Sing -> S3D (uniformSample seed (-1) 1) -- | Generate a shape from a Storable Vector. @@ -177,13 +162,13 @@ randomOfShape = do -- Returns Nothing if the vector is of the wrong size. fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x) fromStorable xs = case sing :: Sing x of - D1Sing SNat -> + D1Sing -> S1D <$> H.create xs - D2Sing SNat SNat -> + D2Sing -> S2D <$> mkL xs - D3Sing SNat SNat SNat -> + D3Sing -> S3D <$> mkL xs where mkL :: forall rows columns. (KnownNat rows, KnownNat columns) @@ -220,13 +205,13 @@ n2 f (S2D x) (S2D y) = S2D (f x y) n2 f (S3D x) (S3D y) = S3D (f x y) -- Helper function for creating the number instances -nk :: forall x. SingI x => Double -> S x +nk :: forall x. (SingI x) => Double -> S x nk x = case (sing :: Sing x) of - D1Sing SNat -> + D1Sing -> S1D (konst x) - D2Sing SNat SNat -> + D2Sing -> S2D (konst x) - D3Sing SNat SNat SNat -> + D3Sing -> S3D (konst x) diff --git a/src/Grenade/Layers/Convolution.hs b/src/Grenade/Layers/Convolution.hs index aa0048b..53ba12f 100644 --- a/src/Grenade/Layers/Convolution.hs +++ b/src/Grenade/Layers/Convolution.hs @@ -28,16 +28,9 @@ import Control.Monad.Random hiding ( fromList ) import Data.Maybe import Data.Proxy import Data.Serialize -import Data.Singletons.TypeLits -#if MIN_VERSION_base(4,11,0) -import GHC.TypeLits hiding (natVal) -#else import GHC.TypeLits -#endif -#if MIN_VERSION_base(4,9,0) import Data.Kind (Type) -#endif import Numeric.LinearAlgebra hiding ( uniformSample, konst ) import qualified Numeric.LinearAlgebra as LA diff --git a/src/Grenade/Layers/Crop.hs b/src/Grenade/Layers/Crop.hs index 0634bca..9770957 100644 --- a/src/Grenade/Layers/Crop.hs +++ b/src/Grenade/Layers/Crop.hs @@ -21,16 +21,9 @@ module Grenade.Layers.Crop ( import Data.Maybe import Data.Proxy import Data.Serialize -import Data.Singletons.TypeLits -#if MIN_VERSION_base(4,11,0) -import GHC.TypeLits hiding (natVal) -#else import GHC.TypeLits -#endif -#if MIN_VERSION_base(4,9,0) import Data.Kind (Type) -#endif import Grenade.Core import Grenade.Layers.Internal.Pad diff --git a/src/Grenade/Layers/Deconvolution.hs b/src/Grenade/Layers/Deconvolution.hs index f46d12b..e8b0ac3 100644 --- a/src/Grenade/Layers/Deconvolution.hs +++ b/src/Grenade/Layers/Deconvolution.hs @@ -32,16 +32,9 @@ import Control.Monad.Random hiding ( fromList ) import Data.Maybe import Data.Proxy import Data.Serialize -import Data.Singletons.TypeLits -#if MIN_VERSION_base(4,11,0) -import GHC.TypeLits hiding (natVal) -#else import GHC.TypeLits -#endif -#if MIN_VERSION_base(4,9,0) import Data.Kind (Type) -#endif import Numeric.LinearAlgebra hiding ( uniformSample, konst ) import qualified Numeric.LinearAlgebra as LA diff --git a/src/Grenade/Layers/FullyConnected.hs b/src/Grenade/Layers/FullyConnected.hs index 7041658..f3ebdea 100644 --- a/src/Grenade/Layers/FullyConnected.hs +++ b/src/Grenade/Layers/FullyConnected.hs @@ -13,7 +13,6 @@ import Control.Monad.Random hiding (fromList) import Data.Proxy import Data.Serialize -import Data.Singletons.TypeLits import qualified Numeric.LinearAlgebra as LA import Numeric.LinearAlgebra.Static @@ -21,6 +20,7 @@ import Numeric.LinearAlgebra.Static import Grenade.Core import Grenade.Layers.Internal.Update +import GHC.TypeLits -- | A basic fully connected (or inner product) neural network layer. data FullyConnected i o = FullyConnected diff --git a/src/Grenade/Layers/Pad.hs b/src/Grenade/Layers/Pad.hs index 652a778..a6c97ac 100644 --- a/src/Grenade/Layers/Pad.hs +++ b/src/Grenade/Layers/Pad.hs @@ -21,16 +21,9 @@ module Grenade.Layers.Pad ( import Data.Maybe import Data.Proxy import Data.Serialize -import Data.Singletons.TypeLits -#if MIN_VERSION_base(4,11,0) -import GHC.TypeLits hiding (natVal) -#else import GHC.TypeLits -#endif -#if MIN_VERSION_base(4,9,0) import Data.Kind (Type) -#endif import Grenade.Core import Grenade.Layers.Internal.Pad diff --git a/src/Grenade/Layers/Pooling.hs b/src/Grenade/Layers/Pooling.hs index ecfb9ee..941ddec 100644 --- a/src/Grenade/Layers/Pooling.hs +++ b/src/Grenade/Layers/Pooling.hs @@ -22,16 +22,9 @@ module Grenade.Layers.Pooling ( import Data.Maybe import Data.Proxy import Data.Serialize -import Data.Singletons.TypeLits -#if MIN_VERSION_base(4,11,0) -import GHC.TypeLits hiding (natVal) -#else import GHC.TypeLits -#endif -#if MIN_VERSION_base(4,9,0) import Data.Kind (Type) -#endif import Grenade.Core import Grenade.Layers.Internal.Pooling diff --git a/src/Grenade/Layers/Reshape.hs b/src/Grenade/Layers/Reshape.hs index 8a657bc..d6979cc 100644 --- a/src/Grenade/Layers/Reshape.hs +++ b/src/Grenade/Layers/Reshape.hs @@ -17,7 +17,6 @@ module Grenade.Layers.Reshape ( import Data.Serialize -import Data.Singletons.TypeLits import GHC.TypeLits import Numeric.LinearAlgebra.Static diff --git a/src/Grenade/Recurrent/Core/Network.hs b/src/Grenade/Recurrent/Core/Network.hs index 62df741..63522e6 100644 --- a/src/Grenade/Recurrent/Core/Network.hs +++ b/src/Grenade/Recurrent/Core/Network.hs @@ -29,16 +29,13 @@ module Grenade.Recurrent.Core.Network ( import Control.Monad.Random ( MonadRandom ) -import Data.Singletons ( SingI ) -import Data.Singletons.Prelude ( Head, Last ) import Data.Serialize -#if MIN_VERSION_base(4,9,0) import Data.Kind (Type) -#endif import Grenade.Core import Grenade.Recurrent.Core.Layer +import Prelude.Singletons -- | Witness type to say indicate we're building up with a normal feed -- forward layer. diff --git a/src/Grenade/Recurrent/Core/Runner.hs b/src/Grenade/Recurrent/Core/Runner.hs index 68cb585..41e6e49 100644 --- a/src/Grenade/Recurrent/Core/Runner.hs +++ b/src/Grenade/Recurrent/Core/Runner.hs @@ -17,10 +17,10 @@ module Grenade.Recurrent.Core.Runner ( ) where import Data.List ( foldl' ) -import Data.Singletons.Prelude import Grenade.Core import Grenade.Recurrent.Core.Network +import Prelude.Singletons type RecurrentGradients layers = [RecurrentGradient layers] diff --git a/src/Grenade/Recurrent/Layers/BasicRecurrent.hs b/src/Grenade/Recurrent/Layers/BasicRecurrent.hs index 53b0133..c2fc018 100644 --- a/src/Grenade/Recurrent/Layers/BasicRecurrent.hs +++ b/src/Grenade/Recurrent/Layers/BasicRecurrent.hs @@ -16,11 +16,8 @@ module Grenade.Recurrent.Layers.BasicRecurrent ( import Control.Monad.Random ( MonadRandom, getRandom ) -import Data.Singletons.TypeLits -#if MIN_VERSION_base(4,9,0) import Data.Kind (Type) -#endif import Numeric.LinearAlgebra.Static diff --git a/src/Grenade/Recurrent/Layers/LSTM.hs b/src/Grenade/Recurrent/Layers/LSTM.hs index 2437148..b1dab76 100644 --- a/src/Grenade/Recurrent/Layers/LSTM.hs +++ b/src/Grenade/Recurrent/Layers/LSTM.hs @@ -21,11 +21,7 @@ import Control.Monad.Random ( MonadRandom, getRandom ) -- import Data.List ( foldl1' ) import Data.Proxy import Data.Serialize -import Data.Singletons.TypeLits - -#if MIN_VERSION_base(4,9,0) import Data.Kind (Type) -#endif import qualified Numeric.LinearAlgebra as LA import Numeric.LinearAlgebra.Static @@ -33,6 +29,7 @@ import Numeric.LinearAlgebra.Static import Grenade.Core import Grenade.Recurrent.Core import Grenade.Layers.Internal.Update +import GHC.TypeLits -- | Long Short Term Memory Recurrent unit diff --git a/src/Grenade/Utils/OneHot.hs b/src/Grenade/Utils/OneHot.hs index 259047a..a819500 100644 --- a/src/Grenade/Utils/OneHot.hs +++ b/src/Grenade/Utils/OneHot.hs @@ -22,7 +22,6 @@ import Data.Map ( Map ) import qualified Data.Map as M import Data.Proxy -import Data.Singletons.TypeLits import Data.Vector ( Vector ) import qualified Data.Vector as V @@ -31,6 +30,7 @@ import qualified Data.Vector.Storable as VS import Numeric.LinearAlgebra ( maxIndex ) import Numeric.LinearAlgebra.Devel import Numeric.LinearAlgebra.Static +import GHC.TypeLits import Grenade.Core.Shape diff --git a/stack.yaml b/stack.yaml index e147d58..0c48389 100644 --- a/stack.yaml +++ b/stack.yaml @@ -20,7 +20,7 @@ # # resolver: ./custom-snapshot.yaml # resolver: https://example.com/snapshots/2018-01-01.yaml -resolver: lts-18.28 +resolver: lts-20.18 # User packages to be built. # Various formats can be used as shown in the example below. @@ -70,3 +70,9 @@ extra-deps: # # Allow a newer minor version of GHC than the snapshot specifies # compiler-check: newer-minor + +nix: + enable: true + packages: + - blas + - lapack diff --git a/stack.yaml.lock b/stack.yaml.lock index 15158f9..3aab884 100644 --- a/stack.yaml.lock +++ b/stack.yaml.lock @@ -7,13 +7,13 @@ packages: - completed: hackage: typelits-witnesses-0.3.0.3@sha256:2d9df4ac6ff3077bfd2bf659e4b495e157723ac5b45c519762853f55df5c16db,2738 pantry-tree: - size: 469 sha256: 6a42a462f98e94933b6e9721acd912c6c6b6a4743635efd15cc5871908c816a0 + size: 469 original: hackage: typelits-witnesses-0.3.0.3 snapshots: - completed: - size: 590100 - url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/18/28.yaml - sha256: 428ec8d5ce932190d3cbe266b9eb3c175cd81e984babf876b64019e2cbe4ea68 - original: lts-18.28 + sha256: 9fa4bece7acfac1fc7930c5d6e24606004b09e80aa0e52e9f68b148201008db9 + size: 649606 + url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/18.yaml + original: lts-20.18