Cleanup imports and move examples to new project

Huw Campbell 2017-02-20 19:16:38 +11:00
parent b855dd140d
commit e6293b8461
27 changed files with 341 additions and 216 deletions

View File

@@ -1,4 +1,4 @@
-Copyright (c) 2016, Huw Campbell
+Copyright (c) 2016-2017, Huw Campbell
 All rights reserved.
 Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

View File

@@ -20,8 +20,12 @@ specified and initialised with random weights in a few lines of code with
 ```haskell
 type MNIST
   = Network
-    '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
-    '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
+    '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu
+     , Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu
+     , FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
+    '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10
+     , 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256
+     , 'D1 80, 'D1 80, 'D1 10, 'D1 10]

 randomMnist :: MonadRandom m => m MNIST
 randomMnist = randomNetwork
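A minimal usage sketch, not from the diff above: draw this network with random weights and run a forward pass on a blank image. It assumes the MNIST type and randomMnist from the snippet are in scope, that runNetwork (as used in the gan-mnist example later in this commit) is available from the Grenade modules, and it uses konst from hmatrix's Numeric.LinearAlgebra.Static for the input.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs     #-}

import Control.Monad.Random (evalRandIO)
import Numeric.LinearAlgebra.Static (konst)
import Grenade

main :: IO ()
main = do
  net <- evalRandIO randomMnist                -- random initial weights
  let blank = S2D (konst 0) :: S ('D2 28 28)   -- a blank 28x28 input image
      (_tape, out) = runNetwork net blank      -- forward pass, out :: S ('D1 10)
  case out of
    S1D scores -> print scores                 -- the ten class activations
```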

examples/LICENSE (new file, 10 lines added)
View File

@@ -0,0 +1,10 @@
Copyright (c) 2016-2017, Huw Campbell
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,108 @@
name: grenade-examples
version: 0.0.1
license: BSD2
license-file: LICENSE
author: Huw Campbell <huw.campbell@gmail.com>
maintainer: Huw Campbell <huw.campbell@gmail.com>
copyright: (c) 2016-2017 Huw Campbell.
synopsis: grenade-examples
category: System
cabal-version: >= 1.8
build-type: Simple
description: grenade-examples
source-repository head
type: git
location: https://github.com/HuwCampbell/grenade.git
library
executable feedforward
ghc-options: -Wall -threaded -O2
main-is: main/feedforward.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix
, transformers
, singletons
, semigroups
, MonadRandom
executable mnist
ghc-options: -Wall -threaded -O2
main-is: main/mnist.hs
build-depends: base
, grenade
, attoparsec
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
, vector
executable gan-mnist
ghc-options: -Wall -threaded -O2
main-is: main/gan-mnist.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
, vector
executable recurrent
ghc-options: -Wall -threaded -O2
main-is: main/recurrent.hs
build-depends: base
, grenade
, attoparsec
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
executable shakespeare
ghc-options: -Wall -threaded -O2
main-is: main/shakespeare.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, vector
, MonadRandom
, containers

View File

@@ -77,9 +77,10 @@ trainExample rate discriminator generator realExample noiseSource
     (discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample
     (discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 )
-    (discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake
-    (generator', _) = runGradient generator generatorTape (-push)
+    (discriminator'fake, _) = runGradient discriminator discriminatorTapeFake guessFake
+    (_, push) = runGradient discriminator discriminatorTapeFake ( guessFake - 1 )
+    (generator', _) = runGradient generator generatorTape push
     newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10 }) discriminator [ discriminator'real, discriminator'fake ]
     newGenerator = applyUpdate rate generator generator'
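In the updated flow the generator is trained against a "real" target for its fake example: backpropagating the error ( guessFake - 1 ) through the discriminator yields the push gradient handed to the generator, whereas the previous code reused the 0-target gradient and negated it.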

View File

@@ -34,8 +34,24 @@ import Grenade.Utils.OneHot
 -- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
 -- this network should get down to about a 1.3% error rate.
-type MNIST = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
-                     '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
+--
+-- /NOTE:/ This model is actually too complex for MNIST, and one should use the type given in the readme instead.
+-- This one is just here to demonstrate Inception layers in use.
+--
+type MNIST =
+  Network
+    '[ Reshape
+     , Inception 28 28 1 5 5 5, Pooling 2 2 2 2, Relu
+     , Inception 14 14 15 5 5 5, Pooling 2 2 2 2, Relu
+     , Reshape
+     , FullyConnected 735 80, Logit
+     , FullyConnected 80 10, Logit]
+    '[ 'D2 28 28, 'D3 28 28 1
+     , 'D3 28 28 15, 'D3 14 14 15, 'D3 14 14 15
+     , 'D3 14 14 15, 'D3 7 7 15, 'D3 7 7 15
+     , 'D1 735
+     , 'D1 80, 'D1 80
+     , 'D1 10, 'D1 10]

 randomMnist :: MonadRandom m => m MNIST
 randomMnist = randomNetwork
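A quick check of the flattened dimensions: the README network reshapes its final 'D3 4 4 16 volume into 'D1 256 (4 × 4 × 16 = 256), and the Inception model above reshapes 'D3 7 7 15 into 'D1 735 (7 × 7 × 15 = 735), which is where the FullyConnected 735 80 input size comes from.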

View File

@@ -4,7 +4,7 @@ license: BSD2
 license-file: LICENSE
 author: Huw Campbell <huw.campbell@gmail.com>
 maintainer: Huw Campbell <huw.campbell@gmail.com>
-copyright: (c) 2015 Huw Campbell.
+copyright: (c) 2016-2017 Huw Campbell.
 synopsis: grenade
 category: System
 cabal-version: >= 1.8
@@ -27,25 +27,26 @@ library
   build-depends:
       base >= 4.8 && < 5
     , bytestring == 0.10.*
-    , containers
-    , deepseq
-    , either == 4.4.*
-    , cereal
+    , containers >= 0.5 && < 0.6
+    , cereal >= 0.5 && < 0.6
+    , deepseq >= 1.4 && < 1.5
     , exceptions == 0.8.*
     , hmatrix == 0.18.*
-    , MonadRandom
+    , MonadRandom >= 0.4 && < 0.6
     , mtl >= 2.2.1 && < 2.3
-    , primitive
+    , primitive >= 0.6 && < 0.7
     , text == 1.2.*
-    , transformers
     , singletons >= 2.1 && < 2.3
-    , vector == 0.11.*
+    , vector >= 0.11 && < 0.13
   ghc-options:
     -Wall
   hs-source-dirs:
     src
+  if impl(ghc < 8.0)
+    ghc-options: -fno-warn-incomplete-patterns
   exposed-modules:
     Grenade
@@ -55,19 +56,24 @@ library
     Grenade.Core.Network
     Grenade.Core.Runner
     Grenade.Core.Shape
-    Grenade.Layers.Crop
+    Grenade.Layers
     Grenade.Layers.Concat
     Grenade.Layers.Convolution
+    Grenade.Layers.Crop
     Grenade.Layers.Dropout
+    Grenade.Layers.Elu
     Grenade.Layers.FullyConnected
-    Grenade.Layers.Reshape
+    Grenade.Layers.Inception
     Grenade.Layers.Logit
     Grenade.Layers.Merge
-    Grenade.Layers.Relu
-    Grenade.Layers.Elu
-    Grenade.Layers.Tanh
     Grenade.Layers.Pad
     Grenade.Layers.Pooling
+    Grenade.Layers.Relu
+    Grenade.Layers.Reshape
+    Grenade.Layers.Softmax
+    Grenade.Layers.Tanh
+    Grenade.Layers.Trivial
     Grenade.Layers.Internal.Convolution
     Grenade.Layers.Internal.Pad
@@ -81,6 +87,7 @@ library
     Grenade.Recurrent.Core.Network
     Grenade.Recurrent.Core.Runner
+    Grenade.Recurrent.Layers
     Grenade.Recurrent.Layers.BasicRecurrent
     Grenade.Recurrent.Layers.LSTM
@@ -95,98 +102,6 @@ library
   cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
executable feedforward
ghc-options: -Wall -threaded -O2
main-is: main/feedforward.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix
, transformers
, singletons
, semigroups
, MonadRandom
executable mnist
ghc-options: -Wall -threaded -O2
main-is: main/mnist.hs
build-depends: base
, grenade
, attoparsec
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
, vector
executable gan-mnist
ghc-options: -Wall -threaded -O2
main-is: main/gan-mnist.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
, vector
executable recurrent
ghc-options: -Wall -threaded -O2
main-is: main/recurrent.hs
build-depends: base
, grenade
, attoparsec
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, MonadRandom
executable shakespeare
ghc-options: -Wall -threaded -O2
main-is: main/shakespeare.hs
build-depends: base
, grenade
, attoparsec
, bytestring
, cereal
, either
, optparse-applicative == 0.13.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix >= 0.18 && < 0.19
, transformers
, semigroups
, singletons
, vector
, MonadRandom
, containers
 test-suite test
   type: exitcode-stdio-1.0

View File

@@ -1,23 +1,52 @@
 module Grenade (
-    module X
+  -- | This is an empty module which simply re-exports public definitions
+  --   for machine learning with Grenade.
+
+  -- * Exported modules
+  --
+  -- | The core types and runners for Grenade.
+    module Grenade.Core
+
+  -- | The neural network layer zoo
+  , module Grenade.Layers
+
+  -- * Overview of the library
+  -- $library
+
+  -- * Example usage
+  -- $example
  ) where

-import Grenade.Core.LearningParameters as X
-import Grenade.Core.Layer as X
-import Grenade.Core.Network as X
-import Grenade.Core.Runner as X
-import Grenade.Core.Shape as X
-import Grenade.Layers.Concat as X
-import Grenade.Layers.Crop as X
-import Grenade.Layers.Dropout as X
-import Grenade.Layers.Pad as X
-import Grenade.Layers.Pooling as X
-import Grenade.Layers.Reshape as X
-import Grenade.Layers.FullyConnected as X
-import Grenade.Layers.Logit as X
-import Grenade.Layers.Merge as X
-import Grenade.Layers.Convolution as X
-import Grenade.Layers.Relu as X
-import Grenade.Layers.Elu as X
-import Grenade.Layers.Tanh as X
-import Grenade.Layers.Softmax as X
+import Grenade.Core
+import Grenade.Layers
+
+{- $library
+Grenade is a purely functional deep learning library.
+
+It provides an expressive type level API for the construction
+of complex neural network architectures. Backing this API is an
+implementation written using BLAS and LAPACK, mostly provided by
+the hmatrix library.
+-}
+
+{- $example
+A few examples are provided at https://github.com/HuwCampbell/grenade
+under the examples folder.
+
+The starting place is to write your neural network type and a
+function to create a random network of that type. The following
+is a simple example which runs a logistic regression.
+
+> type MyNet = Network '[ FullyConnected 10 1, Logit ] '[ 'D1 10, 'D1 1, 'D1 1 ]
+>
+> randomMyNet :: MonadRandom m => m MyNet
+> randomMyNet = randomNetwork
+
+The function `randomMyNet` witnesses the `CreatableNetwork`
+constraint of the neural network; that is, it ensures the network
+can be built and hence that the architecture is sound.
+-}
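A hedged follow-on sketch, not from the module's documentation: folding a training step over a list of labelled examples for MyNet. It assumes MyNet from the example above is in scope and that train from Grenade.Core.Runner takes the learning parameters, the current network, an input and the target output, returning the updated network.

```haskell
{-# LANGUAGE DataKinds #-}

import Data.List (foldl')
import Grenade

-- trainMyNet is a hypothetical helper; one fold is one pass over the data.
trainMyNet :: LearningParameters -> MyNet -> [(S ('D1 10), S ('D1 1))] -> MyNet
trainMyNet params = foldl' (\net (x, y) -> train params net x y)
```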

View File

@@ -1,8 +1,13 @@
 module Grenade.Core (
-    module X
+    module Grenade.Core.Layer
+  , module Grenade.Core.LearningParameters
+  , module Grenade.Core.Network
+  , module Grenade.Core.Runner
+  , module Grenade.Core.Shape
  ) where

-import Grenade.Core.Layer as X
-import Grenade.Core.LearningParameters as X
-import Grenade.Core.Shape as X
-import Grenade.Core.Network as X
+import Grenade.Core.Layer
+import Grenade.Core.LearningParameters
+import Grenade.Core.Network
+import Grenade.Core.Runner
+import Grenade.Core.Shape

View File

@@ -1,4 +1,8 @@
 module Grenade.Core.LearningParameters (
+  -- | This module contains learning algorithm specific
+  --   code. Currently, this module should be considered
+  --   unstable, due to issue #26.
    LearningParameters (..)
  ) where

View File

@@ -1,4 +1,3 @@
-{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE GADTs #-}
@@ -19,10 +18,6 @@ This module defines the core data types and functions
 for non-recurrent neural networks.
 -}
-
-#if __GLASGOW_HASKELL__ < 800
-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
-#endif
 module Grenade.Core.Network (
     Network (..)
   , Gradients (..)
@@ -47,9 +42,9 @@ import Grenade.Core.Shape
 -- | Type of a network.
 --
--- The [*] type specifies the types of the layers.
+-- The @[*]@ type specifies the types of the layers.
 --
--- The [Shape] type specifies the shapes of data passed between the layers.
+-- The @[Shape]@ type specifies the shapes of data passed between the layers.
 --
 -- Can be considered to be a heterogeneous list of layers which are able to
 -- transform the data shapes of the network.
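Concretely, reading the two type lists against the logistic-regression example from the Grenade module's documentation above:

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- One entry per layer in the first list, one entry per shape passed between
-- them in the second: 'D1 10 in, 'D1 1 after FullyConnected 10 1, 'D1 1 out.
type MyNet = Network '[ FullyConnected 10 1, Logit ] '[ 'D1 10, 'D1 1, 'D1 1 ]
```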

View File

@@ -1,4 +1,3 @@
-{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE KindSignatures #-}
@@ -8,11 +7,6 @@
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE RankNTypes #-}
-
--- Ghc 7.10 fails to recognise n2 is complete.
-#if __GLASGOW_HASKELL__ < 800
-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
-#endif
 {-|
 Module      : Grenade.Core.Shape
 Description : Core definition of the Shapes of data we understand
@@ -65,17 +59,14 @@ data Shape
 -- All shapes are held in contiguous memory.
 -- 3D is held in a matrix (usually row oriented) which has height depth * rows.
 data S (n :: Shape) where
-  -- | One dimensional data
   S1D :: ( KnownNat len )
       => R len
       -> S ('D1 len)
-  -- | Two dimensional data
   S2D :: ( KnownNat rows, KnownNat columns )
      => L rows columns
      -> S ('D2 rows columns)
-  -- | Three dimensional data
   S3D :: ( KnownNat rows
          , KnownNat columns
          , KnownNat depth
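A small construction sketch, not part of the diff, assuming vector and konst from hmatrix's Numeric.LinearAlgebra.Static for the statically sized contents:

```haskell
{-# LANGUAGE DataKinds #-}

import Numeric.LinearAlgebra.Static (konst, vector)
import Grenade

blankImage :: S ('D2 28 28)
blankImage = S2D (konst 0)                  -- a 28x28 matrix of zeros

threeValues :: S ('D1 3)
threeValues = S1D (vector [0.1, 0.2, 0.7])  -- a length-3 vector
```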

src/Grenade/Layers.hs (new file, 33 lines added)
View File

@@ -0,0 +1,33 @@
module Grenade.Layers (
module Grenade.Layers.Concat
, module Grenade.Layers.Convolution
, module Grenade.Layers.Crop
, module Grenade.Layers.Elu
, module Grenade.Layers.FullyConnected
, module Grenade.Layers.Inception
, module Grenade.Layers.Logit
, module Grenade.Layers.Merge
, module Grenade.Layers.Pad
, module Grenade.Layers.Pooling
, module Grenade.Layers.Reshape
, module Grenade.Layers.Relu
, module Grenade.Layers.Softmax
, module Grenade.Layers.Tanh
, module Grenade.Layers.Trivial
) where
import Grenade.Layers.Concat
import Grenade.Layers.Convolution
import Grenade.Layers.Crop
import Grenade.Layers.Elu
import Grenade.Layers.Pad
import Grenade.Layers.FullyConnected
import Grenade.Layers.Inception
import Grenade.Layers.Logit
import Grenade.Layers.Merge
import Grenade.Layers.Pooling
import Grenade.Layers.Reshape
import Grenade.Layers.Relu
import Grenade.Layers.Softmax
import Grenade.Layers.Tanh
import Grenade.Layers.Trivial

View File

@@ -9,11 +9,13 @@
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE StandaloneDeriving #-}
 {-|
-Module      : Grenade.Core.Network
-Description : Core definition a simple neural etwork
+Module      : Grenade.Layers.Concat
+Description : Concatenation layer
 Copyright   : (c) Huw Campbell, 2016-2017
 License     : BSD2
 Stability   : experimental
+
+This module provides the concatenation layer, which is used to run two separate layers in parallel and combine their outputs.
 -}
 module Grenade.Layers.Concat (
     Concat (..)

View File

@@ -9,10 +9,13 @@
 {-# LANGUAGE ScopedTypeVariables #-}
 {-|
 Module      : Grenade.Core.Network
-Description : Core definition a simple neural etwork
+Description : Inception style parallel convolutional network composition.
 Copyright   : (c) Huw Campbell, 2016-2017
 License     : BSD2
 Stability   : experimental
+
+Exports an Inception-style type, which can be used to build up
+complex networks with multiple convolution filter sizes.
 -}
 module Grenade.Layers.Inception (
    Inception
@@ -25,24 +28,35 @@ import Grenade.Layers.Convolution
 import Grenade.Layers.Pad
 import Grenade.Layers.Concat

+-- | Type of an Inception layer.
+--
+-- It looks like a bit of a handful, but is actually pretty easy to use.
+--
+-- The first three type parameters are the size of the (3D) data the
+-- Inception layer will take. It will emit 3D data with the number of
+-- channels being the sum of @chx@, @chy@ and @chz@, which are the number
+-- of convolution filters in the 3x3, 5x5, and 7x7 convolution layers
+-- respectively.
+--
+-- The input is effectively padded before each convolution so that the
+-- output has the same x and y dimensions as the input.
 type Inception rows cols channels chx chy chz
-  = Network '[ Concat ('D3 (rows - 2) (cols - 2) (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 (rows - 2) (cols - 2) chz) (Inception7x7 rows cols channels chz) ]
-            '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy + chz) ]
+  = Network '[ Concat ('D3 rows cols (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 rows cols chz) (Inception7x7 rows cols channels chz) ]
+            '[ 'D3 rows cols channels, 'D3 rows cols (chx + chy + chz) ]

 type InceptionS rows cols channels chx chy
-  = Network '[ Concat ('D3 (rows - 2) (cols - 2) chx) (Inception3x3 rows cols channels chx) ('D3 (rows - 2) (cols - 2) chy) (Inception5x5 rows cols channels chy) ]
-            '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy) ]
+  = Network '[ Concat ('D3 rows cols chx) (Inception3x3 rows cols channels chx) ('D3 rows cols chy) (Inception5x5 rows cols channels chy) ]
+            '[ 'D3 rows cols channels, 'D3 rows cols (chx + chy) ]

 type Inception3x3 rows cols channels chx
-  = Network '[ Convolution channels chx 3 3 1 1 ]
-            '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) chx ]
+  = Network '[ Pad 1 1 1 1, Convolution channels chx 3 3 1 1 ]
+            '[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 rows cols chx ]

 type Inception5x5 rows cols channels chx
-  = Network '[ Pad 1 1 1 1, Convolution channels chx 5 5 1 1 ]
-            '[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 (rows - 2) (cols - 2) chx ]
+  = Network '[ Pad 2 2 2 2, Convolution channels chx 5 5 1 1 ]
+            '[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 rows cols chx ]

 type Inception7x7 rows cols channels chx
-  = Network '[ Pad 2 2 2 2, Convolution channels chx 7 7 1 1 ]
-            '[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 (rows - 2) (cols - 2) chx ]
+  = Network '[ Pad 3 3 3 3, Convolution channels chx 7 7 1 1 ]
+            '[ 'D3 rows cols channels, 'D3 (rows + 6) (cols + 6) channels, 'D3 rows cols chx ]
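For a concrete instantiation (the sizes the mnist example in this commit uses), Inception 28 28 1 5 5 5 accepts 'D3 28 28 1 and emits 'D3 28 28 15, since each of the three towers contributes 5 filters. A minimal sketch wrapping it as a single-layer network; the alias name is hypothetical:

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- 5 + 5 + 5 filters from the 3x3, 5x5 and 7x7 towers give 15 output channels,
-- with x and y preserved by the padding inside each tower.
type MnistInception = Network '[ Inception 28 28 1 5 5 5 ]
                              '[ 'D3 28 28 1, 'D3 28 28 15 ]
```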

View File

@@ -9,7 +9,7 @@ module Grenade.Layers.Trivial (
 import Data.Serialize
-import Grenade.Core.Network
+import Grenade.Core

 -- | A trivial layer.
 data Trivial = Trivial
@@ -25,5 +25,6 @@ instance UpdateLayer Trivial where
   createRandom = return Trivial

 instance (a ~ b) => Layer Trivial a b where
-  runForwards _ = id
+  type Tape Trivial a b = ()
+  runForwards _ a = ((), a)
   runBackwards _ _ y = ((), y)
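This brings Trivial in line with the Layer interface in which runForwards returns a back-propagation tape alongside the output: for Trivial the tape is simply (), and runBackwards passes the incoming gradient straight through.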

View File

@@ -1,9 +1,7 @@
 module Grenade.Recurrent (
-    module X
+    module Grenade.Recurrent.Core
+  , module Grenade.Recurrent.Layers
  ) where

-import Grenade.Recurrent.Core.Layer as X
-import Grenade.Recurrent.Core.Network as X
-import Grenade.Recurrent.Core.Runner as X
-import Grenade.Recurrent.Layers.BasicRecurrent as X
-import Grenade.Recurrent.Layers.LSTM as X
+import Grenade.Recurrent.Core
+import Grenade.Recurrent.Layers

View File

@@ -1,6 +1,9 @@
 module Grenade.Recurrent.Core (
-    module X
+    module Grenade.Recurrent.Core.Layer
+  , module Grenade.Recurrent.Core.Network
+  , module Grenade.Recurrent.Core.Runner
  ) where

-import Grenade.Recurrent.Core.Layer as X
-import Grenade.Recurrent.Core.Network as X
+import Grenade.Recurrent.Core.Layer
+import Grenade.Recurrent.Core.Network
+import Grenade.Recurrent.Core.Runner

View File

@@ -1,4 +1,3 @@
-{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE TypeOperators #-}
@@ -10,10 +9,7 @@
 {-# LANGUAGE RankNTypes #-}
 {-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE UndecidableInstances #-}
-
-#if __GLASGOW_HASKELL__ < 800
-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
-#endif
 module Grenade.Recurrent.Core.Network (
    Recurrent

View File

@@ -3,16 +3,11 @@
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE CPP #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE RankNTypes #-}
 {-# LANGUAGE RecordWildCards #-}
-
-#if __GLASGOW_HASKELL__ < 800
-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
-#endif
 module Grenade.Recurrent.Core.Runner (
    trainRecurrent
  , runRecurrent

View File

@@ -0,0 +1,7 @@
module Grenade.Recurrent.Layers (
module Grenade.Recurrent.Layers.BasicRecurrent
, module Grenade.Recurrent.Layers.LSTM
) where
import Grenade.Recurrent.Layers.BasicRecurrent
import Grenade.Recurrent.Layers.LSTM

View File

@@ -1,4 +1,3 @@
-{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE RecordWildCards #-}
@@ -8,10 +7,6 @@
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE UndecidableInstances #-}
-
--- GHC 7.10 doesn't see recurrent run functions as total.
-#if __GLASGOW_HASKELL__ < 800
-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
-#endif
 module Grenade.Recurrent.Layers.BasicRecurrent (
    BasicRecurrent (..)
  , randomBasicRecurrent

View File

@@ -1,5 +1,4 @@
 {-# LANGUAGE BangPatterns #-}
-{-# LANGUAGE CPP #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE RankNTypes #-}
@@ -11,11 +10,6 @@
 {-# LANGUAGE ViewPatterns #-}
 {-# LANGUAGE ScopedTypeVariables #-}
-
--- GHC 7.10 doesn't see recurrent run functions as total.
-#if __GLASGOW_HASKELL__ < 800
-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
-#endif
 module Grenade.Recurrent.Layers.LSTM (
    LSTM (..)
  , LSTMWeights (..)

View File

@@ -30,6 +30,15 @@ prop_pad_crop =
     (_    , grad) = runBackwards net tapes d
   in d ~~~ res .&&. grad ~~~ d

+prop_pad_crop_2d :: Property
+prop_pad_crop_2d =
+  let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D2 7 9, 'D2 16 15, 'D2 7 9 ]
+      net = Pad :~> Crop :~> NNil
+  in gamble genOfShape $ \(d :: S ('D2 7 9)) ->
+    let (tapes, res) = runForwards net d
+        (_    , grad) = runBackwards net tapes d
+    in d ~~~ res .&&. grad ~~~ d
+
 (~~~) :: S x -> S x -> Bool
 (S1D x) ~~~ (S1D y) = norm_Inf (x - y) < 0.00001
 (S2D x) ~~~ (S2D y) = norm_Inf (x - y) < 0.00001
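The shapes in prop_pad_crop_2d line up as follows, assuming Pad's four parameters are the left, top, right and bottom padding: Pad 2 3 4 6 takes 'D2 7 9 to 'D2 (7 + 3 + 6) (9 + 2 + 4) = 'D2 16 15, and Crop 2 3 4 6 removes the same border to give back 'D2 7 9, so the round trip should be the identity the property asserts.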