diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c772e360..e58473cf 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -9,6 +9,7 @@ cuda_add_library(marian_lib
   tensor.cu
   tensor_operators.cu
   expression_operators.cu
+  vocab.cpp
 )
 
 target_link_libraries(marian_lib)
diff --git a/src/test.cu b/src/test.cu
index d27591be..7da85c9d 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -1,13 +1,17 @@
-
+#include <fstream>
 #include "marian.h"
 #include "mnist.h"
+#include "vocab.h"
 
 int main(int argc, char** argv) {
   cudaSetDevice(0);
+  using namespace std;
   using namespace marian;
   using namespace keywords;
 
+  Vocab sourceVocab, targetVocab;
+
   int input_size = 10;
   int output_size = 2;
   int batch_size = 25;
 
@@ -30,6 +34,18 @@ int main(int argc, char** argv) {
   Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh");
   Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0");
 
+  // read parallel corpus from file, building each vocab as we go
+  std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
+  std::fstream targetFile("../examples/mt/dev/newstest2013.en");
+
+  string sourceLine, targetLine;
+  while (getline(sourceFile, sourceLine)) {
+    getline(targetFile, targetLine);
+
+    std::vector<size_t> sourceIds = sourceVocab.ProcessSentence(sourceLine);
+    std::vector<size_t> targetIds = targetVocab.ProcessSentence(targetLine);
+  }
+
   std::cerr << "Building RNN..." << std::endl;
   H.emplace_back(tanh(dot(X[0], Wxh) + dot(h0, Whh) + bh));
   for (int t = 1; t < num_inputs; ++t) {
diff --git a/src/vocab.cpp b/src/vocab.cpp
new file mode 100644
index 00000000..c4e76285
--- /dev/null
+++ b/src/vocab.cpp
@@ -0,0 +1,53 @@
+#include "vocab.h"
+
+using namespace std;
+
+////////////////////////////////////////////////////////
+// Split 'str' on any of the characters in 'delimiters',
+// skipping runs of delimiters (no empty tokens).
+inline std::vector<std::string> Tokenize(const std::string& str,
+                const std::string& delimiters = " \t")
+{
+  std::vector<std::string> tokens;
+  // Skip delimiters at beginning.
+  std::string::size_type lastPos = str.find_first_not_of(delimiters, 0);
+  // Find first "non-delimiter".
+  std::string::size_type pos = str.find_first_of(delimiters, lastPos);
+
+  while (std::string::npos != pos || std::string::npos != lastPos) {
+    // Found a token, add it to the vector.
+    tokens.push_back(str.substr(lastPos, pos - lastPos));
+    // Skip delimiters.  Note the "not_of"
+    lastPos = str.find_first_not_of(delimiters, pos);
+    // Find next "non-delimiter"
+    pos = str.find_first_of(delimiters, lastPos);
+  }
+
+  return tokens;
+}
+////////////////////////////////////////////////////////
+
+// Return the id of 'word', assigning the next free id if unseen.
+size_t Vocab::GetOrCreate(const std::string &word)
+{
+  size_t id;
+  Coll::const_iterator iter = coll_.find(word);
+  if (iter == coll_.end()) {
+    id = coll_.size();
+    coll_[word] = id;
+  }
+  else {
+    id = iter->second;
+  }
+  return id;
+}
+
+// Tokenize 'sentence' and map each token to its vocab id.
+std::vector<size_t> Vocab::ProcessSentence(const std::string &sentence)
+{
+  vector<string> toks = Tokenize(sentence);
+  vector<size_t> ret;
+
+  for (size_t i = 0; i < toks.size(); ++i) {
+    ret.push_back(GetOrCreate(toks[i]));
+  }
+
+  return ret;
+}
diff --git a/src/vocab.h b/src/vocab.h
new file mode 100644
index 00000000..5e055511
--- /dev/null
+++ b/src/vocab.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <string>
+#include <vector>
+#include <unordered_map>
+
+// Word -> id mapping, grown on demand as sentences are processed.
+class Vocab
+{
+public:
+  size_t GetOrCreate(const std::string &word);
+  std::vector<size_t> ProcessSentence(const std::string &sentence);
+
+protected:
+  typedef std::unordered_map<std::string, size_t> Coll;
+  Coll coll_;
+};
+