mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-20 19:44:54 +03:00
Cleanup imports and move examples to new project
This commit is contained in:
parent
b855dd140d
commit
e6293b8461
2
LICENSE
2
LICENSE
@ -1,4 +1,4 @@
|
||||
Copyright (c) 2016, Huw Campbell
|
||||
Copyright (c) 2016-2017, Huw Campbell
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
@ -20,8 +20,12 @@ specified and initialised with random weights in a few lines of code with
|
||||
```haskell
|
||||
type MNIST
|
||||
= Network
|
||||
'[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
|
||||
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
|
||||
'[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu
|
||||
, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu
|
||||
, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
|
||||
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10
|
||||
, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256
|
||||
, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
|
||||
|
||||
randomMnist :: MonadRandom m => m MNIST
|
||||
randomMnist = randomNetwork
|
||||
|
10
examples/LICENSE
Normal file
10
examples/LICENSE
Normal file
@ -0,0 +1,10 @@
|
||||
Copyright (c) 2016-2017, Huw Campbell
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
108
examples/grenade-examples.cabal
Normal file
108
examples/grenade-examples.cabal
Normal file
@ -0,0 +1,108 @@
|
||||
name: grenade-examples
|
||||
version: 0.0.1
|
||||
license: BSD2
|
||||
license-file: LICENSE
|
||||
author: Huw Campbell <huw.campbell@gmail.com>
|
||||
maintainer: Huw Campbell <huw.campbell@gmail.com>
|
||||
copyright: (c) 2016-2017 Huw Campbell.
|
||||
synopsis: grenade-examples
|
||||
category: System
|
||||
cabal-version: >= 1.8
|
||||
build-type: Simple
|
||||
description: grenade-examples
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
location: https://github.com/HuwCampbell/grenade.git
|
||||
|
||||
library
|
||||
|
||||
executable feedforward
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/feedforward.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, bytestring
|
||||
, cereal
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix
|
||||
, transformers
|
||||
, singletons
|
||||
, semigroups
|
||||
, MonadRandom
|
||||
|
||||
executable mnist
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/mnist.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix >= 0.18 && < 0.19
|
||||
, transformers
|
||||
, semigroups
|
||||
, singletons
|
||||
, MonadRandom
|
||||
, vector
|
||||
|
||||
executable gan-mnist
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/gan-mnist.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, bytestring
|
||||
, cereal
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix >= 0.18 && < 0.19
|
||||
, transformers
|
||||
, semigroups
|
||||
, singletons
|
||||
, MonadRandom
|
||||
, vector
|
||||
|
||||
executable recurrent
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/recurrent.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix >= 0.18 && < 0.19
|
||||
, transformers
|
||||
, semigroups
|
||||
, singletons
|
||||
, MonadRandom
|
||||
|
||||
executable shakespeare
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/shakespeare.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, bytestring
|
||||
, cereal
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix >= 0.18 && < 0.19
|
||||
, transformers
|
||||
, semigroups
|
||||
, singletons
|
||||
, vector
|
||||
, MonadRandom
|
||||
, containers
|
@ -58,7 +58,7 @@ import Grenade
|
||||
import Grenade.Utils.OneHot
|
||||
|
||||
type Discriminator = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 1, Logit]
|
||||
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1]
|
||||
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1]
|
||||
|
||||
type Generator = Network '[ FullyConnected 100 10240, Relu, Reshape, Convolution 10 10 5 5 1 1, Relu, Convolution 10 1 1 1 1 1, Logit, Reshape]
|
||||
'[ 'D1 100, 'D1 10240, 'D1 10240, 'D3 32 32 10, 'D3 28 28 10, 'D3 28 28 10, 'D3 28 28 1, 'D3 28 28 1, 'D2 28 28 ]
|
||||
@ -77,9 +77,10 @@ trainExample rate discriminator generator realExample noiseSource
|
||||
(discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample
|
||||
|
||||
(discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 )
|
||||
(discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake
|
||||
(discriminator'fake, _) = runGradient discriminator discriminatorTapeFake guessFake
|
||||
(_, push) = runGradient discriminator discriminatorTapeFake ( guessFake - 1)
|
||||
|
||||
(generator', _) = runGradient generator generatorTape (-push)
|
||||
(generator', _) = runGradient generator generatorTape push
|
||||
|
||||
newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10}) discriminator [ discriminator'real, discriminator'fake ]
|
||||
newGenerator = applyUpdate rate generator generator'
|
@ -34,8 +34,24 @@ import Grenade.Utils.OneHot
|
||||
|
||||
-- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
|
||||
-- this network should get down to about a 1.3% error rate.
|
||||
type MNIST = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
|
||||
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
|
||||
--
|
||||
-- /NOTE:/ This model is actually too complex for MNIST, and one should use the type given in the readme instead.
|
||||
-- This one is just here to demonstrate Inception layers in use.
|
||||
--
|
||||
type MNIST =
|
||||
Network
|
||||
'[ Reshape
|
||||
, Inception 28 28 1 5 5 5, Pooling 2 2 2 2, Relu
|
||||
, Inception 14 14 15 5 5 5, Pooling 2 2 2 2, Relu
|
||||
, Reshape
|
||||
, FullyConnected 735 80, Logit
|
||||
, FullyConnected 80 10, Logit]
|
||||
'[ 'D2 28 28, 'D3 28 28 1
|
||||
, 'D3 28 28 15, 'D3 14 14 15, 'D3 14 14 15
|
||||
, 'D3 14 14 15, 'D3 7 7 15, 'D3 7 7 15
|
||||
, 'D1 735
|
||||
, 'D1 80, 'D1 80
|
||||
, 'D1 10, 'D1 10]
|
||||
|
||||
randomMnist :: MonadRandom m => m MNIST
|
||||
randomMnist = randomNetwork
|
157
grenade.cabal
157
grenade.cabal
@ -4,7 +4,7 @@ license: BSD2
|
||||
license-file: LICENSE
|
||||
author: Huw Campbell <huw.campbell@gmail.com>
|
||||
maintainer: Huw Campbell <huw.campbell@gmail.com>
|
||||
copyright: (c) 2015 Huw Campbell.
|
||||
copyright: (c) 2016-2017 Huw Campbell.
|
||||
synopsis: grenade
|
||||
category: System
|
||||
cabal-version: >= 1.8
|
||||
@ -12,12 +12,12 @@ build-type: Simple
|
||||
description: grenade.
|
||||
|
||||
extra-source-files:
|
||||
cbits/im2col.h
|
||||
cbits/im2col.c
|
||||
cbits/gradient_decent.h
|
||||
cbits/gradient_decent.c
|
||||
cbits/pad.h
|
||||
cbits/pad.c
|
||||
cbits/im2col.h
|
||||
cbits/im2col.c
|
||||
cbits/gradient_decent.h
|
||||
cbits/gradient_decent.c
|
||||
cbits/pad.h
|
||||
cbits/pad.c
|
||||
|
||||
source-repository head
|
||||
type: git
|
||||
@ -27,25 +27,26 @@ library
|
||||
build-depends:
|
||||
base >= 4.8 && < 5
|
||||
, bytestring == 0.10.*
|
||||
, containers
|
||||
, deepseq
|
||||
, either == 4.4.*
|
||||
, cereal
|
||||
, containers >= 0.5 && < 0.6
|
||||
, cereal >= 0.5 && < 0.6
|
||||
, deepseq >= 1.4 && < 1.5
|
||||
, exceptions == 0.8.*
|
||||
, hmatrix == 0.18.*
|
||||
, MonadRandom
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, primitive
|
||||
, MonadRandom >= 0.4 && < 0.6
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, primitive >= 0.6 && < 0.7
|
||||
, text == 1.2.*
|
||||
, transformers
|
||||
, singletons >= 2.1 && < 2.3
|
||||
, vector == 0.11.*
|
||||
, singletons >= 2.1 && < 2.3
|
||||
, vector >= 0.11 && < 0.13
|
||||
|
||||
ghc-options:
|
||||
-Wall
|
||||
hs-source-dirs:
|
||||
src
|
||||
|
||||
if impl(ghc < 8.0)
|
||||
ghc-options: -fno-warn-incomplete-patterns
|
||||
|
||||
|
||||
exposed-modules:
|
||||
Grenade
|
||||
@ -55,19 +56,24 @@ library
|
||||
Grenade.Core.Network
|
||||
Grenade.Core.Runner
|
||||
Grenade.Core.Shape
|
||||
Grenade.Layers.Crop
|
||||
|
||||
Grenade.Layers
|
||||
Grenade.Layers.Concat
|
||||
Grenade.Layers.Convolution
|
||||
Grenade.Layers.Crop
|
||||
Grenade.Layers.Dropout
|
||||
Grenade.Layers.Elu
|
||||
Grenade.Layers.FullyConnected
|
||||
Grenade.Layers.Reshape
|
||||
Grenade.Layers.Inception
|
||||
Grenade.Layers.Logit
|
||||
Grenade.Layers.Merge
|
||||
Grenade.Layers.Relu
|
||||
Grenade.Layers.Elu
|
||||
Grenade.Layers.Tanh
|
||||
Grenade.Layers.Pad
|
||||
Grenade.Layers.Pooling
|
||||
Grenade.Layers.Relu
|
||||
Grenade.Layers.Reshape
|
||||
Grenade.Layers.Softmax
|
||||
Grenade.Layers.Tanh
|
||||
Grenade.Layers.Trivial
|
||||
|
||||
Grenade.Layers.Internal.Convolution
|
||||
Grenade.Layers.Internal.Pad
|
||||
@ -81,111 +87,20 @@ library
|
||||
Grenade.Recurrent.Core.Network
|
||||
Grenade.Recurrent.Core.Runner
|
||||
|
||||
Grenade.Recurrent.Layers
|
||||
Grenade.Recurrent.Layers.BasicRecurrent
|
||||
Grenade.Recurrent.Layers.LSTM
|
||||
|
||||
Grenade.Utils.OneHot
|
||||
|
||||
includes: cbits/im2col.h
|
||||
cbits/gradient_decent.h
|
||||
cbits/pad.h
|
||||
c-sources: cbits/im2col.c
|
||||
cbits/gradient_decent.c
|
||||
cbits/pad.c
|
||||
|
||||
cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
|
||||
|
||||
executable feedforward
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/feedforward.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, bytestring
|
||||
, cereal
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix
|
||||
, transformers
|
||||
, singletons
|
||||
, semigroups
|
||||
, MonadRandom
|
||||
|
||||
executable mnist
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/mnist.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix >= 0.18 && < 0.19
|
||||
, transformers
|
||||
, semigroups
|
||||
, singletons
|
||||
, MonadRandom
|
||||
, vector
|
||||
|
||||
executable gan-mnist
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/gan-mnist.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, bytestring
|
||||
, cereal
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix >= 0.18 && < 0.19
|
||||
, transformers
|
||||
, semigroups
|
||||
, singletons
|
||||
, MonadRandom
|
||||
, vector
|
||||
|
||||
executable recurrent
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/recurrent.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix >= 0.18 && < 0.19
|
||||
, transformers
|
||||
, semigroups
|
||||
, singletons
|
||||
, MonadRandom
|
||||
|
||||
|
||||
executable shakespeare
|
||||
ghc-options: -Wall -threaded -O2
|
||||
main-is: main/shakespeare.hs
|
||||
build-depends: base
|
||||
, grenade
|
||||
, attoparsec
|
||||
, bytestring
|
||||
, cereal
|
||||
, either
|
||||
, optparse-applicative == 0.13.*
|
||||
, text == 1.2.*
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, hmatrix >= 0.18 && < 0.19
|
||||
, transformers
|
||||
, semigroups
|
||||
, singletons
|
||||
, vector
|
||||
, MonadRandom
|
||||
, containers
|
||||
includes: cbits/im2col.h
|
||||
cbits/gradient_decent.h
|
||||
cbits/pad.h
|
||||
c-sources: cbits/im2col.c
|
||||
cbits/gradient_decent.c
|
||||
cbits/pad.c
|
||||
|
||||
cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
|
||||
|
||||
test-suite test
|
||||
type: exitcode-stdio-1.0
|
||||
|
@ -1,23 +1,52 @@
|
||||
module Grenade (
|
||||
module X
|
||||
-- | This is an empty module which simply re-exports public definitions
|
||||
-- for machine learning with Grenade.
|
||||
|
||||
-- * Exported modules
|
||||
--
|
||||
-- | The core types and runners for Grenade.
|
||||
module Grenade.Core
|
||||
|
||||
-- | The neural network layer zoo
|
||||
, module Grenade.Layers
|
||||
|
||||
|
||||
-- * Overview of the library
|
||||
-- $library
|
||||
|
||||
-- * Example usage
|
||||
-- $example
|
||||
|
||||
) where
|
||||
|
||||
import Grenade.Core.LearningParameters as X
|
||||
import Grenade.Core.Layer as X
|
||||
import Grenade.Core.Network as X
|
||||
import Grenade.Core.Runner as X
|
||||
import Grenade.Core.Shape as X
|
||||
import Grenade.Layers.Concat as X
|
||||
import Grenade.Layers.Crop as X
|
||||
import Grenade.Layers.Dropout as X
|
||||
import Grenade.Layers.Pad as X
|
||||
import Grenade.Layers.Pooling as X
|
||||
import Grenade.Layers.Reshape as X
|
||||
import Grenade.Layers.FullyConnected as X
|
||||
import Grenade.Layers.Logit as X
|
||||
import Grenade.Layers.Merge as X
|
||||
import Grenade.Layers.Convolution as X
|
||||
import Grenade.Layers.Relu as X
|
||||
import Grenade.Layers.Elu as X
|
||||
import Grenade.Layers.Tanh as X
|
||||
import Grenade.Layers.Softmax as X
|
||||
import Grenade.Core
|
||||
import Grenade.Layers
|
||||
|
||||
{- $library
|
||||
Grenade is a purely functional deep learning library.
|
||||
|
||||
It provides an expressive type level API for the construction
|
||||
of complex neural network architectures. Backing this API is and
|
||||
implementation written using BLAS and LAPACK, mostly provided by
|
||||
the hmatrix library.
|
||||
-}
|
||||
|
||||
{- $example
|
||||
A few examples are provided at https://github.com/HuwCampbell/grenade
|
||||
under the examples folder.
|
||||
|
||||
The starting place is to write your neural network type and a
|
||||
function to create a random layer of that type. The following
|
||||
is a simple example which runs a logistic regression.
|
||||
|
||||
> type MyNet = Network '[ FullyConnected 10 1, Logit ] '[ 'D1 10, 'D1 1, 'D1 1 ]
|
||||
>
|
||||
> randomMyNet :: MonadRandom MyNet
|
||||
> randomMyNet = randomNetwork
|
||||
|
||||
The function `randomMyNet` witnesses the `CreatableNetwork`
|
||||
constraint of the neural network, that is it ensures the network
|
||||
can be built, and hence, that the architecture is sound.
|
||||
-}
|
||||
|
||||
|
||||
|
@ -1,8 +1,13 @@
|
||||
module Grenade.Core (
|
||||
module X
|
||||
module Grenade.Core.Layer
|
||||
, module Grenade.Core.LearningParameters
|
||||
, module Grenade.Core.Network
|
||||
, module Grenade.Core.Runner
|
||||
, module Grenade.Core.Shape
|
||||
) where
|
||||
|
||||
import Grenade.Core.Layer as X
|
||||
import Grenade.Core.LearningParameters as X
|
||||
import Grenade.Core.Shape as X
|
||||
import Grenade.Core.Network as X
|
||||
import Grenade.Core.Layer
|
||||
import Grenade.Core.LearningParameters
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core.Runner
|
||||
import Grenade.Core.Shape
|
||||
|
@ -1,4 +1,8 @@
|
||||
module Grenade.Core.LearningParameters (
|
||||
-- | This module contains learning algorithm specific
|
||||
-- code. Currently, this module should be consifered
|
||||
-- unstable, due to issue #26.
|
||||
|
||||
LearningParameters (..)
|
||||
) where
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
@ -19,10 +18,6 @@ This module defines the core data types and functions
|
||||
for non-recurrent neural networks.
|
||||
-}
|
||||
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
|
||||
#endif
|
||||
|
||||
module Grenade.Core.Network (
|
||||
Network (..)
|
||||
, Gradients (..)
|
||||
@ -47,9 +42,9 @@ import Grenade.Core.Shape
|
||||
|
||||
-- | Type of a network.
|
||||
--
|
||||
-- The [*] type specifies the types of the layers.
|
||||
-- The @[*]@ type specifies the types of the layers.
|
||||
--
|
||||
-- The [Shape] type specifies the shapes of data passed between the layers.
|
||||
-- The @[Shape]@ type specifies the shapes of data passed between the layers.
|
||||
--
|
||||
-- Can be considered to be a heterogeneous list of layers which are able to
|
||||
-- transform the data shapes of the network.
|
||||
|
@ -1,4 +1,3 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE KindSignatures #-}
|
||||
@ -8,11 +7,6 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
|
||||
-- Ghc 7.10 fails to recognise n2 is complete.
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
|
||||
#endif
|
||||
{-|
|
||||
Module : Grenade.Core.Shape
|
||||
Description : Core definition of the Shapes of data we understand
|
||||
@ -65,17 +59,14 @@ data Shape
|
||||
-- All shapes are held in contiguous memory.
|
||||
-- 3D is held in a matrix (usually row oriented) which has height depth * rows.
|
||||
data S (n :: Shape) where
|
||||
-- | One dimensional data
|
||||
S1D :: ( KnownNat len )
|
||||
=> R len
|
||||
-> S ('D1 len)
|
||||
|
||||
-- | Two dimensional data
|
||||
S2D :: ( KnownNat rows, KnownNat columns )
|
||||
=> L rows columns
|
||||
-> S ('D2 rows columns)
|
||||
|
||||
-- | Three dimensional data
|
||||
S3D :: ( KnownNat rows
|
||||
, KnownNat columns
|
||||
, KnownNat depth
|
||||
|
33
src/Grenade/Layers.hs
Normal file
33
src/Grenade/Layers.hs
Normal file
@ -0,0 +1,33 @@
|
||||
module Grenade.Layers (
|
||||
module Grenade.Layers.Concat
|
||||
, module Grenade.Layers.Convolution
|
||||
, module Grenade.Layers.Crop
|
||||
, module Grenade.Layers.Elu
|
||||
, module Grenade.Layers.FullyConnected
|
||||
, module Grenade.Layers.Inception
|
||||
, module Grenade.Layers.Logit
|
||||
, module Grenade.Layers.Merge
|
||||
, module Grenade.Layers.Pad
|
||||
, module Grenade.Layers.Pooling
|
||||
, module Grenade.Layers.Reshape
|
||||
, module Grenade.Layers.Relu
|
||||
, module Grenade.Layers.Softmax
|
||||
, module Grenade.Layers.Tanh
|
||||
, module Grenade.Layers.Trivial
|
||||
) where
|
||||
|
||||
import Grenade.Layers.Concat
|
||||
import Grenade.Layers.Convolution
|
||||
import Grenade.Layers.Crop
|
||||
import Grenade.Layers.Elu
|
||||
import Grenade.Layers.Pad
|
||||
import Grenade.Layers.FullyConnected
|
||||
import Grenade.Layers.Inception
|
||||
import Grenade.Layers.Logit
|
||||
import Grenade.Layers.Merge
|
||||
import Grenade.Layers.Pooling
|
||||
import Grenade.Layers.Reshape
|
||||
import Grenade.Layers.Relu
|
||||
import Grenade.Layers.Softmax
|
||||
import Grenade.Layers.Tanh
|
||||
import Grenade.Layers.Trivial
|
@ -9,11 +9,13 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-|
|
||||
Module : Grenade.Core.Network
|
||||
Description : Core definition a simple neural etwork
|
||||
Module : Grenade.Layers.Concat
|
||||
Description : Concatenation layer
|
||||
Copyright : (c) Huw Campbell, 2016-2017
|
||||
License : BSD2
|
||||
Stability : experimental
|
||||
|
||||
This module provides the concatenation layer, whic used to run two separate layers in parallel and combine their outputs.
|
||||
-}
|
||||
module Grenade.Layers.Concat (
|
||||
Concat (..)
|
||||
|
@ -9,10 +9,13 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-|
|
||||
Module : Grenade.Core.Network
|
||||
Description : Core definition a simple neural etwork
|
||||
Description : Inception style parallel convolutional network composition.
|
||||
Copyright : (c) Huw Campbell, 2016-2017
|
||||
License : BSD2
|
||||
Stability : experimental
|
||||
|
||||
Export an Inception style type, which can be used to build up
|
||||
complex multiconvolution size networks.
|
||||
-}
|
||||
module Grenade.Layers.Inception (
|
||||
Inception
|
||||
@ -25,24 +28,35 @@ import Grenade.Layers.Convolution
|
||||
import Grenade.Layers.Pad
|
||||
import Grenade.Layers.Concat
|
||||
|
||||
|
||||
-- | Type of an inception layer.
|
||||
--
|
||||
-- It looks like a bit of a handful, but is actually pretty easy to use.
|
||||
--
|
||||
-- The first three type parameters are the size of the (3D) data the
|
||||
-- inception layer will take. It will emit 3D data with the number of
|
||||
-- channels being the sum of @chx@, @chy@, @chz@, which are the number
|
||||
-- of convolution filters in the 3x3, 5x5, and 7x7 convolutions Layers
|
||||
-- respectively.
|
||||
--
|
||||
-- The network get padded effectively before each convolution filters
|
||||
-- such that the output dimension is the same x and y as the input.
|
||||
type Inception rows cols channels chx chy chz
|
||||
= Network '[ Concat ('D3 (rows - 2) (cols - 2) (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 (rows - 2) (cols - 2) chz) (Inception7x7 rows cols channels chz) ]
|
||||
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy + chz) ]
|
||||
= Network '[ Concat ('D3 rows cols (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 rows cols chz) (Inception7x7 rows cols channels chz) ]
|
||||
'[ 'D3 rows cols channels, 'D3 rows cols (chx + chy + chz) ]
|
||||
|
||||
type InceptionS rows cols channels chx chy
|
||||
= Network '[ Concat ('D3 (rows - 2) (cols - 2) chx) (Inception3x3 rows cols channels chx) ('D3 (rows - 2) (cols - 2) chy) (Inception5x5 rows cols channels chy) ]
|
||||
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy) ]
|
||||
= Network '[ Concat ('D3 rows cols chx) (Inception3x3 rows cols channels chx) ('D3 rows cols chy) (Inception5x5 rows cols channels chy) ]
|
||||
'[ 'D3 rows cols channels, 'D3 rows cols (chx + chy) ]
|
||||
|
||||
type Inception3x3 rows cols channels chx
|
||||
= Network '[ Convolution channels chx 3 3 1 1 ]
|
||||
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) chx ]
|
||||
= Network '[ Pad 1 1 1 1, Convolution channels chx 3 3 1 1 ]
|
||||
'[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 rows cols chx ]
|
||||
|
||||
type Inception5x5 rows cols channels chx
|
||||
= Network '[ Pad 1 1 1 1, Convolution channels chx 5 5 1 1 ]
|
||||
'[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 (rows - 2) (cols - 2) chx ]
|
||||
= Network '[ Pad 2 2 2 2, Convolution channels chx 5 5 1 1 ]
|
||||
'[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 rows cols chx ]
|
||||
|
||||
type Inception7x7 rows cols channels chx
|
||||
= Network '[ Pad 2 2 2 2, Convolution channels chx 7 7 1 1 ]
|
||||
'[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 (rows - 2) (cols - 2) chx ]
|
||||
= Network '[ Pad 3 3 3 3, Convolution channels chx 7 7 1 1 ]
|
||||
'[ 'D3 rows cols channels, 'D3 (rows + 6) (cols + 6) channels, 'D3 rows cols chx ]
|
||||
|
||||
|
@ -9,7 +9,7 @@ module Grenade.Layers.Trivial (
|
||||
|
||||
import Data.Serialize
|
||||
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core
|
||||
|
||||
-- | A trivial layer.
|
||||
data Trivial = Trivial
|
||||
@ -25,5 +25,6 @@ instance UpdateLayer Trivial where
|
||||
createRandom = return Trivial
|
||||
|
||||
instance (a ~ b) => Layer Trivial a b where
|
||||
runForwards _ = id
|
||||
type Tape Trivial a b = ()
|
||||
runForwards _ a = ((), a)
|
||||
runBackwards _ _ y = ((), y)
|
||||
|
@ -1,9 +1,7 @@
|
||||
module Grenade.Recurrent (
|
||||
module X
|
||||
module Grenade.Recurrent.Core
|
||||
, module Grenade.Recurrent.Layers
|
||||
) where
|
||||
|
||||
import Grenade.Recurrent.Core.Layer as X
|
||||
import Grenade.Recurrent.Core.Network as X
|
||||
import Grenade.Recurrent.Core.Runner as X
|
||||
import Grenade.Recurrent.Layers.BasicRecurrent as X
|
||||
import Grenade.Recurrent.Layers.LSTM as X
|
||||
import Grenade.Recurrent.Core
|
||||
import Grenade.Recurrent.Layers
|
||||
|
@ -1,6 +1,9 @@
|
||||
module Grenade.Recurrent.Core (
|
||||
module X
|
||||
module Grenade.Recurrent.Core.Layer
|
||||
, module Grenade.Recurrent.Core.Network
|
||||
, module Grenade.Recurrent.Core.Runner
|
||||
) where
|
||||
|
||||
import Grenade.Recurrent.Core.Layer as X
|
||||
import Grenade.Recurrent.Core.Network as X
|
||||
import Grenade.Recurrent.Core.Layer
|
||||
import Grenade.Recurrent.Core.Network
|
||||
import Grenade.Recurrent.Core.Runner
|
||||
|
@ -1,4 +1,3 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
@ -10,10 +9,7 @@
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
|
||||
#endif
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
|
||||
module Grenade.Recurrent.Core.Network (
|
||||
Recurrent
|
||||
|
@ -3,16 +3,11 @@
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
|
||||
#endif
|
||||
|
||||
module Grenade.Recurrent.Core.Runner (
|
||||
trainRecurrent
|
||||
, runRecurrent
|
||||
|
7
src/Grenade/Recurrent/Layers.hs
Normal file
7
src/Grenade/Recurrent/Layers.hs
Normal file
@ -0,0 +1,7 @@
|
||||
module Grenade.Recurrent.Layers (
|
||||
module Grenade.Recurrent.Layers.BasicRecurrent
|
||||
, module Grenade.Recurrent.Layers.LSTM
|
||||
) where
|
||||
|
||||
import Grenade.Recurrent.Layers.BasicRecurrent
|
||||
import Grenade.Recurrent.Layers.LSTM
|
@ -1,4 +1,3 @@
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
@ -8,10 +7,6 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
|
||||
-- GHC 7.10 doesn't see recurrent run functions as total.
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
|
||||
#endif
|
||||
module Grenade.Recurrent.Layers.BasicRecurrent (
|
||||
BasicRecurrent (..)
|
||||
, randomBasicRecurrent
|
||||
|
@ -1,5 +1,4 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
@ -11,11 +10,6 @@
|
||||
{-# LANGUAGE ViewPatterns #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
-- GHC 7.10 doesn't see recurrent run functions as total.
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
|
||||
#endif
|
||||
|
||||
module Grenade.Recurrent.Layers.LSTM (
|
||||
LSTM (..)
|
||||
, LSTMWeights (..)
|
||||
|
@ -30,6 +30,15 @@ prop_pad_crop =
|
||||
(_ , grad) = runBackwards net tapes d
|
||||
in d ~~~ res .&&. grad ~~~ d
|
||||
|
||||
prop_pad_crop_2d :: Property
|
||||
prop_pad_crop_2d =
|
||||
let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D2 7 9, 'D2 16 15, 'D2 7 9 ]
|
||||
net = Pad :~> Crop :~> NNil
|
||||
in gamble genOfShape $ \(d :: S ('D2 7 9)) ->
|
||||
let (tapes, res) = runForwards net d
|
||||
(_ , grad) = runBackwards net tapes d
|
||||
in d ~~~ res .&&. grad ~~~ d
|
||||
|
||||
(~~~) :: S x -> S x -> Bool
|
||||
(S1D x) ~~~ (S1D y) = norm_Inf (x - y) < 0.00001
|
||||
(S2D x) ~~~ (S2D y) = norm_Inf (x - y) < 0.00001
|
||||
|
Loading…
Reference in New Issue
Block a user