Mirror of https://github.com/HuwCampbell/grenade.git (synced 2024-11-22 06:55:13 +03:00)
Update hmatrix

commit d360438fc0
parent 2cb93c6a5b
@@ -76,7 +76,7 @@ executable mnist
                     , optparse-applicative == 0.12.*
                     , text == 1.2.*
                     , mtl >= 2.2.1 && < 2.3
-                    , hmatrix
+                    , hmatrix >= 0.18 && < 0.19
                     , transformers
                     , singletons
                     , MonadRandom
@@ -1 +0,0 @@
-Subproject commit 9aade51bd0bb6339cfa8aca014bd96f801d9b19e
@@ -16,7 +16,7 @@ import Data.Singletons.Prelude
 import Grenade.Core.Network
 import Grenade.Core.Shape
 
--- | Drive and network and collect it's back propogated gradients.
+-- | Drive and network and collect its back propogated gradients.
 backPropagate :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
               => Network layers shapes -> S' input -> S' output -> Gradients layers
 backPropagate network input target =
@@ -29,7 +29,7 @@ backPropagate network input target =
     -- handle input from the beginning, feeding upwards.
     go !x (layer :~> n)
       = let y             = runForwards layer x
-            -- recursively run the rest of the network, and get the layer from above.
+            -- recursively run the rest of the network, and get the gradients from above.
             (n', dWs')    = go y n
             -- calculate the gradient for this layer to pass down,
             (layer', dWs) = runBackards layer x dWs'
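The hunks above only show the recursive case of the inner go helper; the base case and grenade's layer types are not part of this diff. As a purely illustrative, self-contained sketch (not grenade's actual API), the same feed-forward-then-back-propagate recursion can be written over a plain chain of scalar weights, with a hypothetical base case supplying the initial gradient from the target:

{-# LANGUAGE BangPatterns #-}
-- Illustrative sketch only: this is NOT grenade's API. It mimics the shape of
-- the 'go' helper in the diff above, using a list of scalar weights as the
-- "network" so the forward/backward recursion is easy to follow.
module Main where

type Layer    = Double   -- a "layer" is just a single weight
type Network  = [Layer]
type Gradient = Double

-- Drive the network forward, then collect the back propagated gradients,
-- one per layer, for a squared-error style loss against 'target'.
backPropagate :: Network -> Double -> Double -> [Gradient]
backPropagate network input target = snd (go input network)
  where
    -- handle input from the beginning, feeding upwards.
    go :: Double -> Network -> (Double, [Gradient])
    go !x (w : rest)
      = let y          = w * x     -- forwards through this layer
            -- recursively run the rest of the network, and get the gradient
            -- from above.
            (dy, dWs') = go y rest
            -- gradient for this layer's weight, and the gradient to pass down
            dW         = dy * x
            dx         = dy * w
        in  (dx, dW : dWs')
    -- assumed base case: the gradient flowing back in is output minus target
    go !x [] = (x - target, [])

main :: IO ()
main = print (backPropagate [0.5, 2.0] 1.0 3.0)  -- prints [-4.0,-1.0]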