Cleanup imports and move examples to new project

This commit is contained in:
Huw Campbell 2017-02-20 19:16:38 +11:00
parent b855dd140d
commit e6293b8461
27 changed files with 341 additions and 216 deletions

View File

@ -1,4 +1,4 @@
Copyright (c) 2016, Huw Campbell
Copyright (c) 2016-2017, Huw Campbell
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

View File

@ -20,8 +20,12 @@ specified and initialised with random weights in a few lines of code with
```haskell
type MNIST
= Network
'[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
'[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu
, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu
, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10
, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256
, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
randomMnist :: MonadRandom m => m MNIST
randomMnist = randomNetwork

10
examples/LICENSE Normal file
View File

@ -0,0 +1,10 @@
Copyright (c) 2016-2017, Huw Campbell
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,108 @@
name: grenade-examples
version: 0.0.1
license: BSD2
license-file: LICENSE
author: Huw Campbell <huw.campbell@gmail.com>
maintainer: Huw Campbell <huw.campbell@gmail.com>
copyright: (c) 2016-2017 Huw Campbell.
synopsis: grenade-examples
category: System
cabal-version: >= 1.8
build-type: Simple
description: grenade-examples
source-repository head
type: git
location: https://github.com/HuwCampbell/grenade.git
library
executable feedforward
ghc-options: -Wall -threaded -O2
main-is: main/feedforward.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix
, transformers
, singletons
, semigroups
, MonadRandom
executable mnist
ghc-options: -Wall -threaded -O2
main-is: main/mnist.hs
build-depends: base
, grenade
, attoparsec
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
, vector
executable gan-mnist
ghc-options: -Wall -threaded -O2
main-is: main/gan-mnist.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
, vector
executable recurrent
ghc-options: -Wall -threaded -O2
main-is: main/recurrent.hs
build-depends: base
, grenade
, attoparsec
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
executable shakespeare
ghc-options: -Wall -threaded -O2
main-is: main/shakespeare.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, vector
, MonadRandom
, containers

View File

@ -77,9 +77,10 @@ trainExample rate discriminator generator realExample noiseSource
(discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample
(discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 )
(discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake
(discriminator'fake, _) = runGradient discriminator discriminatorTapeFake guessFake
(_, push) = runGradient discriminator discriminatorTapeFake ( guessFake - 1)
(generator', _) = runGradient generator generatorTape (-push)
(generator', _) = runGradient generator generatorTape push
newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10}) discriminator [ discriminator'real, discriminator'fake ]
newGenerator = applyUpdate rate generator generator'

View File

@ -34,8 +34,24 @@ import Grenade.Utils.OneHot
-- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
-- this network should get down to about a 1.3% error rate.
type MNIST = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
--
-- /NOTE:/ This model is actually too complex for MNIST, and one should use the type given in the readme instead.
-- This one is just here to demonstrate Inception layers in use.
--
type MNIST =
Network
'[ Reshape
, Inception 28 28 1 5 5 5, Pooling 2 2 2 2, Relu
, Inception 14 14 15 5 5 5, Pooling 2 2 2 2, Relu
, Reshape
, FullyConnected 735 80, Logit
, FullyConnected 80 10, Logit]
'[ 'D2 28 28, 'D3 28 28 1
, 'D3 28 28 15, 'D3 14 14 15, 'D3 14 14 15
, 'D3 14 14 15, 'D3 7 7 15, 'D3 7 7 15
, 'D1 735
, 'D1 80, 'D1 80
, 'D1 10, 'D1 10]
randomMnist :: MonadRandom m => m MNIST
randomMnist = randomNetwork

View File

@ -4,7 +4,7 @@ license: BSD2
license-file: LICENSE
author: Huw Campbell <huw.campbell@gmail.com>
maintainer: Huw Campbell <huw.campbell@gmail.com>
copyright: (c) 2015 Huw Campbell.
copyright: (c) 2016-2017 Huw Campbell.
synopsis: grenade
category: System
cabal-version: >= 1.8
@ -27,25 +27,26 @@ library
build-depends:
base >= 4.8 && < 5
, bytestring == 0.10.*
, containers
, deepseq
, either == 4.4.*
, cereal
, containers >= 0.5 && < 0.6
, cereal >= 0.5 && < 0.6
, deepseq >= 1.4 && < 1.5
, exceptions == 0.8.*
, hmatrix == 0.18.*
, MonadRandom
, MonadRandom >= 0.4 && < 0.6
, mtl >= 2.2.1 && < 2.3
, primitive
, primitive >= 0.6 && < 0.7
, text == 1.2.*
, transformers
, singletons >= 2.1 && < 2.3
, vector == 0.11.*
, vector >= 0.11 && < 0.13
ghc-options:
-Wall
hs-source-dirs:
src
if impl(ghc < 8.0)
ghc-options: -fno-warn-incomplete-patterns
exposed-modules:
Grenade
@ -55,19 +56,24 @@ library
Grenade.Core.Network
Grenade.Core.Runner
Grenade.Core.Shape
Grenade.Layers.Crop
Grenade.Layers
Grenade.Layers.Concat
Grenade.Layers.Convolution
Grenade.Layers.Crop
Grenade.Layers.Dropout
Grenade.Layers.Elu
Grenade.Layers.FullyConnected
Grenade.Layers.Reshape
Grenade.Layers.Inception
Grenade.Layers.Logit
Grenade.Layers.Merge
Grenade.Layers.Relu
Grenade.Layers.Elu
Grenade.Layers.Tanh
Grenade.Layers.Pad
Grenade.Layers.Pooling
Grenade.Layers.Relu
Grenade.Layers.Reshape
Grenade.Layers.Softmax
Grenade.Layers.Tanh
Grenade.Layers.Trivial
Grenade.Layers.Internal.Convolution
Grenade.Layers.Internal.Pad
@ -81,6 +87,7 @@ library
Grenade.Recurrent.Core.Network
Grenade.Recurrent.Core.Runner
Grenade.Recurrent.Layers
Grenade.Recurrent.Layers.BasicRecurrent
Grenade.Recurrent.Layers.LSTM
@ -95,98 +102,6 @@ library
cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
executable feedforward
ghc-options: -Wall -threaded -O2
main-is: main/feedforward.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix
, transformers
, singletons
, semigroups
, MonadRandom
executable mnist
ghc-options: -Wall -threaded -O2
main-is: main/mnist.hs
build-depends: base
, grenade
, attoparsec
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
, vector
executable gan-mnist
ghc-options: -Wall -threaded -O2
main-is: main/gan-mnist.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
, vector
executable recurrent
ghc-options: -Wall -threaded -O2
main-is: main/recurrent.hs
build-depends: base
, grenade
, attoparsec
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
executable shakespeare
ghc-options: -Wall -threaded -O2
main-is: main/shakespeare.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, vector
, MonadRandom
, containers
test-suite test
type: exitcode-stdio-1.0

View File

@ -1,23 +1,52 @@
module Grenade (
module X
-- | This is an empty module which simply re-exports public definitions
-- for machine learning with Grenade.
-- * Exported modules
--
-- | The core types and runners for Grenade.
module Grenade.Core
-- | The neural network layer zoo
, module Grenade.Layers
-- * Overview of the library
-- $library
-- * Example usage
-- $example
) where
import Grenade.Core.LearningParameters as X
import Grenade.Core.Layer as X
import Grenade.Core.Network as X
import Grenade.Core.Runner as X
import Grenade.Core.Shape as X
import Grenade.Layers.Concat as X
import Grenade.Layers.Crop as X
import Grenade.Layers.Dropout as X
import Grenade.Layers.Pad as X
import Grenade.Layers.Pooling as X
import Grenade.Layers.Reshape as X
import Grenade.Layers.FullyConnected as X
import Grenade.Layers.Logit as X
import Grenade.Layers.Merge as X
import Grenade.Layers.Convolution as X
import Grenade.Layers.Relu as X
import Grenade.Layers.Elu as X
import Grenade.Layers.Tanh as X
import Grenade.Layers.Softmax as X
import Grenade.Core
import Grenade.Layers
{- $library
Grenade is a purely functional deep learning library.
It provides an expressive type level API for the construction
of complex neural network architectures. Backing this API is and
implementation written using BLAS and LAPACK, mostly provided by
the hmatrix library.
-}
{- $example
A few examples are provided at https://github.com/HuwCampbell/grenade
under the examples folder.
The starting place is to write your neural network type and a
function to create a random layer of that type. The following
is a simple example which runs a logistic regression.
> type MyNet = Network '[ FullyConnected 10 1, Logit ] '[ 'D1 10, 'D1 1, 'D1 1 ]
>
> randomMyNet :: MonadRandom MyNet
> randomMyNet = randomNetwork
The function `randomMyNet` witnesses the `CreatableNetwork`
constraint of the neural network, that is it ensures the network
can be built, and hence, that the architecture is sound.
-}

View File

@ -1,8 +1,13 @@
module Grenade.Core (
module X
module Grenade.Core.Layer
, module Grenade.Core.LearningParameters
, module Grenade.Core.Network
, module Grenade.Core.Runner
, module Grenade.Core.Shape
) where
import Grenade.Core.Layer as X
import Grenade.Core.LearningParameters as X
import Grenade.Core.Shape as X
import Grenade.Core.Network as X
import Grenade.Core.Layer
import Grenade.Core.LearningParameters
import Grenade.Core.Network
import Grenade.Core.Runner
import Grenade.Core.Shape

View File

@ -1,4 +1,8 @@
module Grenade.Core.LearningParameters (
-- | This module contains learning algorithm specific
-- code. Currently, this module should be consifered
-- unstable, due to issue #26.
LearningParameters (..)
) where

View File

@ -1,4 +1,3 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GADTs #-}
@ -19,10 +18,6 @@ This module defines the core data types and functions
for non-recurrent neural networks.
-}
#if __GLASGOW_HASKELL__ < 800
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
#endif
module Grenade.Core.Network (
Network (..)
, Gradients (..)
@ -47,9 +42,9 @@ import Grenade.Core.Shape
-- | Type of a network.
--
-- The [*] type specifies the types of the layers.
-- The @[*]@ type specifies the types of the layers.
--
-- The [Shape] type specifies the shapes of data passed between the layers.
-- The @[Shape]@ type specifies the shapes of data passed between the layers.
--
-- Can be considered to be a heterogeneous list of layers which are able to
-- transform the data shapes of the network.

View File

@ -1,4 +1,3 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
@ -8,11 +7,6 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
-- Ghc 7.10 fails to recognise n2 is complete.
#if __GLASGOW_HASKELL__ < 800
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
#endif
{-|
Module : Grenade.Core.Shape
Description : Core definition of the Shapes of data we understand
@ -65,17 +59,14 @@ data Shape
-- All shapes are held in contiguous memory.
-- 3D is held in a matrix (usually row oriented) which has height depth * rows.
data S (n :: Shape) where
-- | One dimensional data
S1D :: ( KnownNat len )
=> R len
-> S ('D1 len)
-- | Two dimensional data
S2D :: ( KnownNat rows, KnownNat columns )
=> L rows columns
-> S ('D2 rows columns)
-- | Three dimensional data
S3D :: ( KnownNat rows
, KnownNat columns
, KnownNat depth

33
src/Grenade/Layers.hs Normal file
View File

@ -0,0 +1,33 @@
module Grenade.Layers (
module Grenade.Layers.Concat
, module Grenade.Layers.Convolution
, module Grenade.Layers.Crop
, module Grenade.Layers.Elu
, module Grenade.Layers.FullyConnected
, module Grenade.Layers.Inception
, module Grenade.Layers.Logit
, module Grenade.Layers.Merge
, module Grenade.Layers.Pad
, module Grenade.Layers.Pooling
, module Grenade.Layers.Reshape
, module Grenade.Layers.Relu
, module Grenade.Layers.Softmax
, module Grenade.Layers.Tanh
, module Grenade.Layers.Trivial
) where
import Grenade.Layers.Concat
import Grenade.Layers.Convolution
import Grenade.Layers.Crop
import Grenade.Layers.Elu
import Grenade.Layers.Pad
import Grenade.Layers.FullyConnected
import Grenade.Layers.Inception
import Grenade.Layers.Logit
import Grenade.Layers.Merge
import Grenade.Layers.Pooling
import Grenade.Layers.Reshape
import Grenade.Layers.Relu
import Grenade.Layers.Softmax
import Grenade.Layers.Tanh
import Grenade.Layers.Trivial

View File

@ -9,11 +9,13 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-|
Module : Grenade.Core.Network
Description : Core definition a simple neural etwork
Module : Grenade.Layers.Concat
Description : Concatenation layer
Copyright : (c) Huw Campbell, 2016-2017
License : BSD2
Stability : experimental
This module provides the concatenation layer, whic used to run two separate layers in parallel and combine their outputs.
-}
module Grenade.Layers.Concat (
Concat (..)

View File

@ -9,10 +9,13 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-|
Module : Grenade.Core.Network
Description : Core definition a simple neural etwork
Description : Inception style parallel convolutional network composition.
Copyright : (c) Huw Campbell, 2016-2017
License : BSD2
Stability : experimental
Export an Inception style type, which can be used to build up
complex multiconvolution size networks.
-}
module Grenade.Layers.Inception (
Inception
@ -25,24 +28,35 @@ import Grenade.Layers.Convolution
import Grenade.Layers.Pad
import Grenade.Layers.Concat
-- | Type of an inception layer.
--
-- It looks like a bit of a handful, but is actually pretty easy to use.
--
-- The first three type parameters are the size of the (3D) data the
-- inception layer will take. It will emit 3D data with the number of
-- channels being the sum of @chx@, @chy@, @chz@, which are the number
-- of convolution filters in the 3x3, 5x5, and 7x7 convolutions Layers
-- respectively.
--
-- The network get padded effectively before each convolution filters
-- such that the output dimension is the same x and y as the input.
type Inception rows cols channels chx chy chz
= Network '[ Concat ('D3 (rows - 2) (cols - 2) (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 (rows - 2) (cols - 2) chz) (Inception7x7 rows cols channels chz) ]
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy + chz) ]
= Network '[ Concat ('D3 rows cols (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 rows cols chz) (Inception7x7 rows cols channels chz) ]
'[ 'D3 rows cols channels, 'D3 rows cols (chx + chy + chz) ]
type InceptionS rows cols channels chx chy
= Network '[ Concat ('D3 (rows - 2) (cols - 2) chx) (Inception3x3 rows cols channels chx) ('D3 (rows - 2) (cols - 2) chy) (Inception5x5 rows cols channels chy) ]
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy) ]
= Network '[ Concat ('D3 rows cols chx) (Inception3x3 rows cols channels chx) ('D3 rows cols chy) (Inception5x5 rows cols channels chy) ]
'[ 'D3 rows cols channels, 'D3 rows cols (chx + chy) ]
type Inception3x3 rows cols channels chx
= Network '[ Convolution channels chx 3 3 1 1 ]
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) chx ]
= Network '[ Pad 1 1 1 1, Convolution channels chx 3 3 1 1 ]
'[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 rows cols chx ]
type Inception5x5 rows cols channels chx
= Network '[ Pad 1 1 1 1, Convolution channels chx 5 5 1 1 ]
'[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 (rows - 2) (cols - 2) chx ]
= Network '[ Pad 2 2 2 2, Convolution channels chx 5 5 1 1 ]
'[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 rows cols chx ]
type Inception7x7 rows cols channels chx
= Network '[ Pad 2 2 2 2, Convolution channels chx 7 7 1 1 ]
'[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 (rows - 2) (cols - 2) chx ]
= Network '[ Pad 3 3 3 3, Convolution channels chx 7 7 1 1 ]
'[ 'D3 rows cols channels, 'D3 (rows + 6) (cols + 6) channels, 'D3 rows cols chx ]

View File

@ -9,7 +9,7 @@ module Grenade.Layers.Trivial (
import Data.Serialize
import Grenade.Core.Network
import Grenade.Core
-- | A trivial layer.
data Trivial = Trivial
@ -25,5 +25,6 @@ instance UpdateLayer Trivial where
createRandom = return Trivial
instance (a ~ b) => Layer Trivial a b where
runForwards _ = id
type Tape Trivial a b = ()
runForwards _ a = ((), a)
runBackwards _ _ y = ((), y)

View File

@ -1,9 +1,7 @@
module Grenade.Recurrent (
module X
module Grenade.Recurrent.Core
, module Grenade.Recurrent.Layers
) where
import Grenade.Recurrent.Core.Layer as X
import Grenade.Recurrent.Core.Network as X
import Grenade.Recurrent.Core.Runner as X
import Grenade.Recurrent.Layers.BasicRecurrent as X
import Grenade.Recurrent.Layers.LSTM as X
import Grenade.Recurrent.Core
import Grenade.Recurrent.Layers

View File

@ -1,6 +1,9 @@
module Grenade.Recurrent.Core (
module X
module Grenade.Recurrent.Core.Layer
, module Grenade.Recurrent.Core.Network
, module Grenade.Recurrent.Core.Runner
) where
import Grenade.Recurrent.Core.Layer as X
import Grenade.Recurrent.Core.Network as X
import Grenade.Recurrent.Core.Layer
import Grenade.Recurrent.Core.Network
import Grenade.Recurrent.Core.Runner

View File

@ -1,4 +1,3 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
@ -10,10 +9,7 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
#if __GLASGOW_HASKELL__ < 800
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
#endif
{-# LANGUAGE UndecidableInstances #-}
module Grenade.Recurrent.Core.Network (
Recurrent

View File

@ -3,16 +3,11 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
#if __GLASGOW_HASKELL__ < 800
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
#endif
module Grenade.Recurrent.Core.Runner (
trainRecurrent
, runRecurrent

View File

@ -0,0 +1,7 @@
module Grenade.Recurrent.Layers (
module Grenade.Recurrent.Layers.BasicRecurrent
, module Grenade.Recurrent.Layers.LSTM
) where
import Grenade.Recurrent.Layers.BasicRecurrent
import Grenade.Recurrent.Layers.LSTM

View File

@ -1,4 +1,3 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
@ -8,10 +7,6 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
-- GHC 7.10 doesn't see recurrent run functions as total.
#if __GLASGOW_HASKELL__ < 800
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
#endif
module Grenade.Recurrent.Layers.BasicRecurrent (
BasicRecurrent (..)
, randomBasicRecurrent

View File

@ -1,5 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
@ -11,11 +10,6 @@
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- GHC 7.10 doesn't see recurrent run functions as total.
#if __GLASGOW_HASKELL__ < 800
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
#endif
module Grenade.Recurrent.Layers.LSTM (
LSTM (..)
, LSTMWeights (..)

View File

@ -30,6 +30,15 @@ prop_pad_crop =
(_ , grad) = runBackwards net tapes d
in d ~~~ res .&&. grad ~~~ d
prop_pad_crop_2d :: Property
prop_pad_crop_2d =
let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D2 7 9, 'D2 16 15, 'D2 7 9 ]
net = Pad :~> Crop :~> NNil
in gamble genOfShape $ \(d :: S ('D2 7 9)) ->
let (tapes, res) = runForwards net d
(_ , grad) = runBackwards net tapes d
in d ~~~ res .&&. grad ~~~ d
(~~~) :: S x -> S x -> Bool
(S1D x) ~~~ (S1D y) = norm_Inf (x - y) < 0.00001
(S2D x) ~~~ (S2D y) = norm_Inf (x - y) < 0.00001