From c7a1542b028b13d95327712afa817a1828c8d417 Mon Sep 17 00:00:00 2001 From: Andre Martins Date: Fri, 16 Sep 2016 14:25:06 +0100 Subject: [PATCH 1/4] Separating graph from data. --- src/validate_encoder_decoder.cu | 165 +++++++++++++++++--------------- 1 file changed, 90 insertions(+), 75 deletions(-) diff --git a/src/validate_encoder_decoder.cu b/src/validate_encoder_decoder.cu index c7be225a..d8eea261 100644 --- a/src/validate_encoder_decoder.cu +++ b/src/validate_encoder_decoder.cu @@ -2,11 +2,78 @@ #include "marian.h" #include "mnist.h" -#if 0 -ExpressionGraph build_graph() { - std::cerr << "Loading model params..."; +using namespace marian; +using namespace keywords; + +const int input_size = 10; +const int output_size = 15; +const int batch_size = 25; +const int hidden_size = 5; +const int num_inputs = 8; +const int num_outputs = 6; + +ExpressionGraph build_graph(int cuda_device) { + std::cerr << "Building computation graph..." << std::endl; + + ExpressionGraph g(cuda_device); + std::vector X, Y, H, S; + + // For the stop symbol. + for (int t = 0; t <= num_inputs; ++t) { + std::stringstream ss; + ss << "X" << t; + X.emplace_back(named(g.input(shape={batch_size, input_size}), ss.str())); + } + + // For the stop symbol. + for (int t = 0; t <= num_outputs; ++t) { + std::stringstream ss; + ss << "Y" << t; + Y.emplace_back(named(g.input(shape={batch_size, output_size}), ss.str())); + } + + Expr Wxh = g.param(shape={input_size, hidden_size}, init=uniform(), name="Wxh"); + Expr Whh = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh"); + Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh"); + Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0"); + + std::cerr << "Building encoder RNN..." << std::endl; + H.emplace_back(tanh(dot(X[0], Wxh) + dot(h0, Whh) + bh)); + for (int t = 1; t <= num_inputs; ++t) { + H.emplace_back(tanh(dot(X[t], Wxh) + dot(H[t-1], Whh) + bh)); + } + + Expr Wxh_d = g.param(shape={output_size, hidden_size}, init=uniform(), name="Wxh_d"); + Expr Whh_d = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh_d"); + Expr bh_d = g.param(shape={1, hidden_size}, init=uniform(), name="bh_d"); + + std::cerr << "Building decoder RNN..." << std::endl; + auto h0_d = H[num_inputs]; + S.emplace_back(tanh(dot(Y[0], Wxh_d) + dot(h0_d, Whh_d) + bh_d)); + for (int t = 1; t < num_outputs; ++t) { + S.emplace_back(tanh(dot(Y[t], Wxh_d) + dot(S[t-1], Whh_d) + bh_d)); + } + + Expr Why = g.param(shape={hidden_size, output_size}, init=uniform(), name="Why"); + Expr by = g.param(shape={1, output_size}, init=uniform(), name="by"); + + std::cerr << "Building output layer..." << std::endl; + std::vector Yp; + + Yp.emplace_back(named(softmax_fast(dot(h0_d, Why) + by), "pred")); + Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1); + for (int t = 1; t <= num_outputs; ++t) { + Yp.emplace_back(named(softmax_fast(dot(S[t-1], Why) + by), "pred")); + cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1); + } + auto graph = -mean(cross_entropy, axis=0, name="cost"); + + std::cerr << "Done." << std::endl; + + return g; } +#if 0 // read parallel corpus from file std::fstream sourceFile("../examples/mt/dev/newstest2013.de"); std::fstream targetFile("../examples/mt/dev/newstest2013.en"); @@ -21,73 +88,8 @@ ExpressionGraph build_graph() { int main(int argc, char** argv) { - cudaSetDevice(0); - using namespace marian; - using namespace keywords; - - int input_size = 10; - int output_size = 15; - int batch_size = 25; - int hidden_size = 5; - int num_inputs = 8; - int num_outputs = 6; - - ExpressionGraph g; - std::vector X(num_inputs+1); // For the stop symbol. - std::vector Y(num_outputs); - std::vector H(num_inputs+1); // For the stop symbol. - std::vector S(num_outputs); - - // For the stop symbol. - for (int t = 0; t <= num_inputs; ++t) { - X[t] = new Expr(g.input(shape={batch_size, input_size})); - } - - // For the stop symbol. - for (int t = 0; t <= num_outputs; ++t) { - Y[t] = new Expr(g.input(shape={batch_size, output_size})); - } - - Expr Wxh = g.param(shape={input_size, hidden_size}, init=uniform(), name="Wxh"); - Expr Whh = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh"); - Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh"); - Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0"); - - std::cerr << "Building encoder RNN..." << std::endl; - H[0] = new Expr(tanh(dot(*X[0], Wxh) + dot(h0, Whh) + bh)); - for (int t = 1; t <= num_inputs; ++t) { - H[t] = new Expr(tanh(dot(*X[t], Wxh) + dot(*H[t-1], Whh) + bh)); - } - - Expr Wxh_d = g.param(shape={output_size, hidden_size}, init=uniform(), name="Wxh_d"); - Expr Whh_d = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh_d"); - Expr bh_d = g.param(shape={1, hidden_size}, init=uniform(), name="bh_d"); - - std::cerr << "Building decoder RNN..." << std::endl; - auto h0_d = *H[num_inputs]; - S[0] = new Expr(tanh(dot(*Y[0], Wxh_d) + dot(h0_d, Whh_d) + bh_d)); - for (int t = 1; t < num_outputs; ++t) { - S[t] = new Expr(tanh(dot(*Y[t], Wxh_d) + dot(*S[t-1], Whh_d) + bh_d)); - } - - Expr Why = g.param(shape={hidden_size, output_size}, init=uniform(), name="Why"); - Expr by = g.param(shape={1, output_size}, init=uniform(), name="by"); - - std::cerr << "Building output layer..." << std::endl; - std::vector Yp(num_outputs+1); // For the stop symbol. - - Expr* cross_entropy = NULL; - for (int t = 0; t <= num_outputs; ++t) { - if (t == 0) { - Yp[t] = new Expr(named(softmax_fast(dot(h0_d, Why) + by), "pred")); - cross_entropy = new Expr(sum(*Y[t] * log(*Yp[t]), axis=1)); - } else { - Yp[t] = new Expr(named(softmax_fast(dot(*S[t-1], Why) + by), "pred")); - *cross_entropy = *cross_entropy + sum(*Y[t] * log(*Yp[t]), axis=1); - } - } - auto graph = -mean(*cross_entropy, axis=0, name="cost"); + ExpressionGraph g = build_graph(0); // For the stop symbol. for (int t = 0; t <= num_inputs; ++t) { @@ -105,10 +107,13 @@ int main(int argc, char** argv) { thrust::copy(values.begin(), values.end(), Xt.begin()); - *X[t] = Xt; + std::stringstream ss; + ss << "X" << t; + g[ss.str()] = Xt; + } - for (int t = 0; t < num_outputs; ++t) { + for (int t = 0; t <= num_outputs; ++t) { Tensor Yt({batch_size, output_size}); std::vector classes(batch_size * output_size, 0.0); @@ -121,23 +126,33 @@ int main(int argc, char** argv) { thrust::copy(classes.begin(), classes.end(), Yt.begin()); - *Y[t] = Yt; + std::stringstream ss; + ss << "Y" << t; + g[ss.str()] = Yt; } + std::cerr << "Graphviz step" << std::endl; + std::cout << g.graphviz() << std::endl; + + std::cerr << "Forward step" << std::endl; g.forward(batch_size); + std::cerr << "Backward step" << std::endl; g.backward(); + std::cerr << "Done" << std::endl; - std::cerr << graph.val().Debug() << std::endl; + std::cerr << g["graph"].val().Debug() << std::endl; - std::cerr << X[0]->val().Debug() << std::endl; - std::cerr << Y[0]->val().Debug() << std::endl; + std::cerr << g["X0"].val().Debug() << std::endl; + std::cerr << g["Y0"].val().Debug() << std::endl; +#if 0 std::cerr << Whh.grad().Debug() << std::endl; std::cerr << bh.grad().Debug() << std::endl; std::cerr << Why.grad().Debug() << std::endl; std::cerr << by.grad().Debug() << std::endl; std::cerr << Wxh.grad().Debug() << std::endl; std::cerr << h0.grad().Debug() << std::endl; +#endif return 0; } From b6c25a8db709942d018b44b31139cc4389770082 Mon Sep 17 00:00:00 2001 From: Maximiliana Behnke Date: Fri, 16 Sep 2016 16:49:05 +0200 Subject: [PATCH 2/4] Add Doxygen comments to tensor.h --- src/tensor.h | 254 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 247 insertions(+), 7 deletions(-) diff --git a/src/tensor.h b/src/tensor.h index a32e9b04..a18aaa3f 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -1,4 +1,21 @@ #pragma once +/* Copyright (C) + * 2016 - MLAMU & friends + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. + * + */ #include #include @@ -12,6 +29,13 @@ namespace marian { +/** + * @brief Debug shape by printing it. + * + * @param shape Shape of Tensor. + * + * @return + */ inline std::string Debug(const Shape &shape) { std::stringstream strm; @@ -23,6 +47,13 @@ inline std::string Debug(const Shape &shape) return strm.str(); } +/** + * @brief Calculate the vector size based on Tensor shape. + * + * @param shape Shape of Tensor. + * + * @return + */ inline size_t GetTotalSize(const Shape &shape) { size_t ret = std::accumulate(shape.begin(), shape.end(), @@ -30,17 +61,28 @@ inline size_t GetTotalSize(const Shape &shape) return ret; } +/** + * @brief Class that manages the Tensor on the GPU. + * + * @tparam Float Data type. + */ template class TensorImpl { private: - Shape shape_; - thrust::device_vector data_; - size_t tno_; - static size_t tensorCounter; + Shape shape_; /*!< Dimens of Tensor */ + thrust::device_vector data_; /*< Vector of data that Tensor is managing on GPU. */ + size_t tno_; /*< Tensor number */ + static size_t tensorCounter; /*< Static counter of created Tensors */ public: - typedef Float value_type; + typedef Float value_type; /*< Tensor value type */ + /** + * @brief Constructor + * + * @param shape Shape of Tensor. + * @param value Value to fill Tensor's vector with. + */ TensorImpl(const Shape& shape, value_type value = 0) : shape_(shape), tno_(tensorCounter++) { @@ -59,54 +101,122 @@ class TensorImpl { TensorImpl(const TensorImpl&) = delete; TensorImpl(TensorImpl&&) = delete; + /** + * @brief Get value of vector specified with index. + * + * @param i Index. + * + * @return Value of Tensor vector indexed with i. + */ value_type operator[](size_t i) const { return data_[i]; } + /** + * @brief Get begin iterator of Tensor's vector. + * + * @return Vector begin iterator. + */ auto begin() -> decltype( data_.begin() ) { return data_.begin(); } + /** + * @brief Get begin iterator of Tensor's vector (const). + * + * @return Vector begin iterator (const) + */ auto begin() const -> decltype( data_.begin() ) { return data_.begin(); } + /** + * @brief Get end iterator of Tensor's vector. + * + * @return Vector end iterator + */ auto end() -> decltype( data_.end() ) { return data_.end(); } + /** + * @brief Get end iterator of Tensor's vector (const). + * + * @return Vector end iterator (const) + */ auto end() const -> decltype( data_.end() ) { return data_.end(); } + /** + * @brief Get Tensor's shape (const) + * + * @return Shape of Tensor + */ const Shape& shape() const { return shape_; } + /** + * @brief Get size of Tensor's vector. + * + * @return Length of Tensor's vector. + */ size_t size() const { return data_.size(); } + /** + * @brief Cast data from Tensor's GPU to value_type. + * + * @return Pointer of value_type array. + */ value_type* data() { return thrust::raw_pointer_cast(data_.data()); } + /** + * @brief Get Tensor id (number). + * + * @return Tensor id. + */ size_t id() const { return tno_; } + /** + * @brief Fill Tensor's vector with specified value on the GPU. + * + * @param value Value to fill vector with. + */ void set(value_type value) { thrust::fill(data_.begin(), data_.end(), value); } + /** + * @brief Set Tensor's vector to values of specified vector by copying it to GPU. + * + * @param begin Begin iterator of a vector. + * @param end End iterator of a vector. + */ void set(const std::vector::const_iterator &begin, const std::vector::const_iterator &end) { thrust::copy(begin, end, data_.begin()); } + /** + * @brief Copy Tensor's vector from GPU to vector variable on CPU. + * + * @param out Vector to copy data to. + */ void get(std::vector::iterator out) { thrust::copy(data_.begin(), data_.end(), out); } + /** + * @brief Debug function. + * + * @return Vector in string form. + */ std::string Debug() const { std::stringstream strm; @@ -133,78 +243,170 @@ class TensorImpl { template size_t TensorImpl::tensorCounter = 0; +/** + * @brief Class that communicates with GPU's Tensor. + */ class Tensor { private: - std::shared_ptr> pimpl_; + std::shared_ptr> pimpl_; /*< Pointer to Tensor working on GPU */ public: - typedef TensorImpl::value_type value_type; + typedef TensorImpl::value_type value_type; /*< Get value type of GPU's Tensor data */ + /** + * @brief Default constructor + */ Tensor() {} + + /** + * @brief Constructor that allocates needed memory. + * + * @param shape Shape of Tensor. + * @param value Value to fill Tensor's vector with. + */ Tensor(const Shape& shape, value_type value = 0) { allocate(shape, value); } + /** + * @brief Default destructor + */ ~Tensor() {} + /** + * @brief Allocate memory if Tensor doesn't exist on GPU. + * + * @param shape Shape of Tensor. + * @param value Value to fill Tensor's vector with. + */ void allocate(const Shape& shape, value_type value = 0) { if(!pimpl_) pimpl_.reset(new TensorImpl(shape, value)); } + /** + * @brief Get value of GPU Tensor in specified index (const). + * + * @param i Index. + * + * @return Value of specified element of Tensor. + */ value_type operator[](size_t i) const { return (*pimpl_)[i]; } + /** + * @brief Get size of GPU Tensor's vector. + * + * @return + */ size_t size() const { return pimpl_->size(); } + /** + * @brief Return pointer to GPU Tensor's data. + * + * @return Pointer to GPU Tensor's data. + */ value_type* data() { return pimpl_->data(); } + /** + * @brief Return pointer to GPU Tensor's data (const). + * + * @return Pointer to GPU Tensor's data. + */ const value_type* data() const { return pimpl_->data(); } + /** + * @brief Get begin iterator of GPU Tensor's vector. + * + * @return Vector begin iterator. + */ auto begin() -> decltype( pimpl_->begin() ) { return pimpl_->begin(); } + /** + * @brief Get begin iterator of GPU Tensor's vector (const). + * + * @return Vector begin iterator (const) + */ auto begin() const -> decltype( pimpl_->begin() ) { return pimpl_->begin(); } + /** + * @brief Get end iterator of Tensor's vector. + * + * @return Vector end iterator + */ auto end() -> decltype( pimpl_->end() ) { return pimpl_->end(); } + /** + * @brief Get end iterator of Tensor's vector (const). + * + * @return Vector end iterator (const) + */ auto end() const -> decltype( pimpl_->end() ) { return pimpl_->end(); } + /** + * @brief Get GPU Tensor's shape. + * + * @return Tensor's shape. + */ const Shape& shape() const { return pimpl_->shape(); } + /** + * @brief Fill GPU Tensor's vector with specified value. + * + * @param value Value to fill Tensor with. + */ void set(value_type value) { pimpl_->set(value); } + /** + * @brief Get GPU Tensor id (number). + * + * @return Tensor id. + */ size_t id() const { return pimpl_->id(); } + /** + * @brief Check if Tensor exists (is filled with data). + * + * @return True or False + */ operator bool() { return pimpl_ != nullptr; } + /** + * @brief Run Debug on GPU Tensor. + * + * @return String of Tensor's data. + */ std::string Debug() const { return pimpl_->Debug(); } + /** + * @brief Print Tensor data on CPU (?) (const). + */ void Print() const { for (int i = 0; i < size(); ++i) { std::cerr << (*this)[i] << " "; @@ -213,21 +415,59 @@ class Tensor { } //void Load(const std::string &path); + + /** + * @brief Set GPU Tensor's vector to values of specified vector. + * + * @param data Vector copied to GPU. + */ void set(const std::vector& data); + /** + * @brief Set GPU Tensor's veector to values of specified vector. + * + * @param begin Begin iterator of vector being copied. + * @param end End iterator of vector being copied. + */ void set(const std::vector::const_iterator &begin, const std::vector::const_iterator &end); + /** + * @brief Copy Tensor's vector from GPU to vector variable on CPU (const). + * + * @param out Vector iterator used in copying. + */ void get(std::vector::iterator out) const { pimpl_->get(out); } + /** + * @brief Copy Tensor's vector from GPU to vector variable on CPU. + * + * @param out Vector to copy data to. + */ void get(std::vector &vout) const { vout.resize(size()); pimpl_->get(vout.begin()); } }; +/** + * @brief Operator to set data on Tensor using vector. + * + * @param t Tensor. + * @param vec Vector used to set data in Tensor. + * + * @return Tensor with assigned data. + */ Tensor& operator<<(Tensor& t, const std::vector &vec); +/** + * @brief Operator to get data from Tensor to vector. + * + * @param vec Vector to save copied data. + * @param t Tensor to copy data from. + * + * @return Vector with copied data. + */ std::vector& operator<<(std::vector &vec, const Tensor& t); } From 18018841f3d953303bb0583c6c8b639c4b7db8e5 Mon Sep 17 00:00:00 2001 From: Maximiliana Behnke Date: Fri, 16 Sep 2016 17:10:02 +0200 Subject: [PATCH 3/4] Cosmetic changes to comments in tensor.h --- src/tensor.h | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/tensor.h b/src/tensor.h index a18aaa3f..d083435e 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -34,7 +34,7 @@ namespace marian { * * @param shape Shape of Tensor. * - * @return + * @return String of shape. */ inline std::string Debug(const Shape &shape) { @@ -52,7 +52,7 @@ inline std::string Debug(const Shape &shape) * * @param shape Shape of Tensor. * - * @return + * @return Size of Tensor vector. */ inline size_t GetTotalSize(const Shape &shape) { @@ -62,14 +62,14 @@ inline size_t GetTotalSize(const Shape &shape) } /** - * @brief Class that manages the Tensor on the GPU. + * @brief This class manages the Tensor on the GPU. * * @tparam Float Data type. */ template class TensorImpl { private: - Shape shape_; /*!< Dimens of Tensor */ + Shape shape_; /*!< Dimenions of Tensor */ thrust::device_vector data_; /*< Vector of data that Tensor is managing on GPU. */ size_t tno_; /*< Tensor number */ static size_t tensorCounter; /*< Static counter of created Tensors */ @@ -102,7 +102,7 @@ class TensorImpl { TensorImpl(TensorImpl&&) = delete; /** - * @brief Get value of vector specified with index. + * @brief Get the i-th element of Tensor vector. * * @param i Index. * @@ -259,7 +259,7 @@ class Tensor { Tensor() {} /** - * @brief Constructor that allocates needed memory. + * @brief Constructor that allocates memory. * * @param shape Shape of Tensor. * @param value Value to fill Tensor's vector with. @@ -274,7 +274,7 @@ class Tensor { ~Tensor() {} /** - * @brief Allocate memory if Tensor doesn't exist on GPU. + * @brief Allocate memory if Tensor doesn't exist on GPU. Otherwise, do nothing. * * @param shape Shape of Tensor. * @param value Value to fill Tensor's vector with. @@ -285,7 +285,7 @@ class Tensor { } /** - * @brief Get value of GPU Tensor in specified index (const). + * @brief Get i-th element of GPU Tensor vector (const). * * @param i Index. * @@ -298,7 +298,7 @@ class Tensor { /** * @brief Get size of GPU Tensor's vector. * - * @return + * @return Size of Tensor vector. */ size_t size() const { return pimpl_->size(); @@ -386,7 +386,7 @@ class Tensor { } /** - * @brief Check if Tensor exists (is filled with data). + * @brief Check if Tensor is allocated. * * @return True or False */ @@ -423,7 +423,7 @@ class Tensor { */ void set(const std::vector& data); /** - * @brief Set GPU Tensor's veector to values of specified vector. + * @brief Fill GPU Tensor's vector using values from the specified vector. * * @param begin Begin iterator of vector being copied. * @param end End iterator of vector being copied. From 1b27accaa04c072628108e4ce84029baff18f550 Mon Sep 17 00:00:00 2001 From: Andre Martins Date: Fri, 16 Sep 2016 16:35:17 +0100 Subject: [PATCH 4/4] Included embedding layer and graphviz part in the e-d. --- src/validate_encoder_decoder.cu | 73 ++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/src/validate_encoder_decoder.cu b/src/validate_encoder_decoder.cu index ded9982e..1df1a897 100644 --- a/src/validate_encoder_decoder.cu +++ b/src/validate_encoder_decoder.cu @@ -9,8 +9,9 @@ using namespace keywords; const int input_size = 10; const int output_size = 15; -const int batch_size = 25; +const int embedding_size = 8; const int hidden_size = 5; +const int batch_size = 25; const int num_inputs = 8; const int num_outputs = 6; @@ -20,34 +21,47 @@ ExpressionGraph build_graph(int cuda_device) { ExpressionGraph g(cuda_device); std::vector X, Y, H, S; - // For the stop symbol. + // We're including the stop symbol here. for (int t = 0; t <= num_inputs; ++t) { std::stringstream ss; ss << "X" << t; X.emplace_back(named(g.input(shape={batch_size, input_size}), ss.str())); } - // For the stop symbol. + // We're including the stop symbol here. for (int t = 0; t <= num_outputs; ++t) { std::stringstream ss; ss << "Y" << t; Y.emplace_back(named(g.input(shape={batch_size, output_size}), ss.str())); } - Expr Wxh = named(g.param(shape={input_size, hidden_size}, init=uniform()), "Wxh"); - Expr Whh = named(g.param(shape={hidden_size, hidden_size}, init=uniform()), "Whh"); - Expr bh = named(g.param(shape={1, hidden_size}, init=uniform()), "bh"); - Expr h0 = named(g.param(shape={1, hidden_size}, init=uniform()), "h0"); + // Source embeddings. + Expr E = named(g.param(shape={input_size, embedding_size}, + init=uniform()), "E"); + + // Source RNN parameters. + Expr Wxh = named(g.param(shape={embedding_size, hidden_size}, + init=uniform()), "Wxh"); + Expr Whh = named(g.param(shape={hidden_size, hidden_size}, + init=uniform()), "Whh"); + Expr bh = named(g.param(shape={1, hidden_size}, + init=uniform()), "bh"); + Expr h0 = named(g.param(shape={1, hidden_size}, + init=uniform()), "h0"); std::cerr << "Building encoder RNN..." << std::endl; - H.emplace_back(tanh(dot(X[0], Wxh) + dot(h0, Whh) + bh)); + H.emplace_back(tanh(dot(dot(X[0], E), Wxh) + dot(h0, Whh) + bh)); for (int t = 1; t <= num_inputs; ++t) { - H.emplace_back(tanh(dot(X[t], Wxh) + dot(H[t-1], Whh) + bh)); + H.emplace_back(tanh(dot(dot(X[t], E), Wxh) + dot(H[t-1], Whh) + bh)); } - Expr Wxh_d = named(g.param(shape={output_size, hidden_size}, init=uniform()), "Wxh_d"); - Expr Whh_d = named(g.param(shape={hidden_size, hidden_size}, init=uniform()), "Whh_d"); - Expr bh_d = named(g.param(shape={1, hidden_size}, init=uniform()), "bh_d"); + // Target RNN parameters. + Expr Wxh_d = named(g.param(shape={output_size, hidden_size}, + init=uniform()), "Wxh_d"); + Expr Whh_d = named(g.param(shape={hidden_size, hidden_size}, + init=uniform()), "Whh_d"); + Expr bh_d = named(g.param(shape={1, hidden_size}, + init=uniform()), "bh_d"); std::cerr << "Building decoder RNN..." << std::endl; auto h0_d = H[num_inputs]; @@ -56,12 +70,16 @@ ExpressionGraph build_graph(int cuda_device) { S.emplace_back(tanh(dot(Y[t], Wxh_d) + dot(S[t-1], Whh_d) + bh_d)); } - Expr Why = named(g.param(shape={hidden_size, output_size}, init=uniform()), "Why"); - Expr by = named(g.param(shape={1, output_size}, init=uniform()), "by"); + // Output linear layer before softmax. + Expr Why = named(g.param(shape={hidden_size, output_size}, + init=uniform()), "Why"); + Expr by = named(g.param(shape={1, output_size}, + init=uniform()), "by"); std::cerr << "Building output layer..." << std::endl; - std::vector Yp; + // Softmax layer and cost function. + std::vector Yp; Yp.emplace_back(named(softmax_fast(dot(h0_d, Why) + by), "pred")); Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1); for (int t = 1; t <= num_outputs; ++t) { @@ -75,8 +93,6 @@ ExpressionGraph build_graph(int cuda_device) { return g; } - - int main(int argc, char** argv) { #if 1 std::cerr << "Loading the data... "; @@ -102,12 +118,12 @@ int main(int argc, char** argv) { std::cerr << "Target vocabulary size: " << targetVocab.Size() << std::endl; #endif + // Build the encoder-decoder computation graph. ExpressionGraph g = build_graph(0); - // For the stop symbol. + // Generate input data (include the stop symbol). for (int t = 0; t <= num_inputs; ++t) { Tensor Xt({batch_size, input_size}); - float max = 1.; std::vector values(batch_size * input_size); std::vector classes(batch_size * output_size, 0.0); @@ -117,16 +133,13 @@ int main(int argc, char** argv) { values[k] = max * (2.0*static_cast(rand()) / RAND_MAX - 1.0); } } - thrust::copy(values.begin(), values.end(), Xt.begin()); - std::stringstream ss; ss << "X" << t; - if (!g.has_node(ss.str())) std::cerr << "No node " << ss.str() << "!!!" << std::endl; g[ss.str()] = Xt; - } + // Generate output data (include the stop symbol). for (int t = 0; t <= num_outputs; ++t) { Tensor Yt({batch_size, output_size}); @@ -137,37 +150,31 @@ int main(int argc, char** argv) { classes[l + gold] = 1.0; l += output_size; } - thrust::copy(classes.begin(), classes.end(), Yt.begin()); - std::stringstream ss; ss << "Y" << t; - if (!g.has_node(ss.str())) std::cerr << "No node " << ss.str() << "!!!" << std::endl; g[ss.str()] = Yt; } - std::cerr << "Graphviz step" << std::endl; + std::cerr << "Printing the computation graph..." << std::endl; std::cout << g.graphviz() << std::endl; - std::cerr << "Forward step" << std::endl; + std::cerr << "Running the forward step..." << std::endl; g.forward(batch_size); - std::cerr << "Backward step" << std::endl; + std::cerr << "Running the backward step..." << std::endl; g.backward(); - std::cerr << "Done" << std::endl; + std::cerr << "Done." << std::endl; std::cerr << g["cost"].val().Debug() << std::endl; std::cerr << g["X0"].val().Debug() << std::endl; std::cerr << g["Y0"].val().Debug() << std::endl; - -#if 1 std::cerr << g["Whh"].grad().Debug() << std::endl; std::cerr << g["bh"].grad().Debug() << std::endl; std::cerr << g["Why"].grad().Debug() << std::endl; std::cerr << g["by"].grad().Debug() << std::endl; std::cerr << g["Wxh"].grad().Debug() << std::endl; std::cerr << g["h0"].grad().Debug() << std::endl; -#endif return 0; }