parse parallel data

This commit is contained in:
Hieu Hoang 2016-09-16 13:29:49 +02:00
parent 383b82c6f9
commit dfae0a2585
4 changed files with 88 additions and 1 deletions

View File

@ -9,6 +9,7 @@ cuda_add_library(marian_lib
tensor.cu
tensor_operators.cu
expression_operators.cu
vocab.cpp
)
target_link_libraries(marian_lib)

View File

@ -1,13 +1,17 @@
#include <fstream>
#include "marian.h"
#include "mnist.h"
#include "vocab.h"
int main(int argc, char** argv) {
cudaSetDevice(0);
using namespace std;
using namespace marian;
using namespace keywords;
Vocab sourceVocab, targetVocab;
int input_size = 10;
int output_size = 2;
int batch_size = 25;
@ -30,6 +34,18 @@ int main(int argc, char** argv) {
Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh");
Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0");
// read parallel corpus from file
std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
std::fstream targetFile("../examples/mt/dev/newstest2013.en");
string sourceLine, targetLine;
while (getline(sourceFile, sourceLine)) {
getline(targetFile, targetLine);
std::vector<size_t> sourceIds = sourceVocab.ProcessSentence(sourceLine);
std::vector<size_t> targetIds = sourceVocab.ProcessSentence(targetLine);
}
std::cerr << "Building RNN..." << std::endl;
H.emplace_back(tanh(dot(X[0], Wxh) + dot(h0, Whh) + bh));
for (int t = 1; t < num_inputs; ++t) {

53
src/vocab.cpp Normal file
View File

@ -0,0 +1,53 @@
#include "vocab.h"
using namespace std;
////////////////////////////////////////////////////////
// Split `str` into tokens separated by any character found in `delimiters`.
// Runs of consecutive delimiters are collapsed, and leading/trailing
// delimiters produce no empty tokens (whitespace-style tokenization).
inline std::vector<std::string> Tokenize(const std::string& str,
				const std::string& delimiters = " \t")
{
	std::vector<std::string> tokens;
	// Position of the first non-delimiter character of the next token.
	std::string::size_type start = str.find_first_not_of(delimiters);
	while (start != std::string::npos) {
		// End of the current token: next delimiter (or end of string).
		const std::string::size_type end = str.find_first_of(delimiters, start);
		if (end == std::string::npos) {
			tokens.push_back(str.substr(start));
			break;
		}
		tokens.push_back(str.substr(start, end - start));
		start = str.find_first_not_of(delimiters, end);
	}
	return tokens;
}
////////////////////////////////////////////////////////
size_t Vocab::GetOrCreate(const std::string &word)
{
size_t id;
Coll::const_iterator iter = coll_.find(word);
if (iter == coll_.end()) {
id = coll_.size();
coll_[word] = id;
}
else {
id = iter->second;
}
return id;
}
// Tokenize `sentence` on whitespace and map each token to its vocabulary id,
// creating ids for unseen words. Returns the ids in token order.
//
// Bug fix: the original wrote `ret[i] = id;` into a default-constructed
// (empty) vector, which is an out-of-bounds write (undefined behavior) for
// every non-empty sentence. Use reserve() + push_back() instead.
std::vector<size_t> Vocab::ProcessSentence(const std::string &sentence)
{
	vector<string> toks = Tokenize(sentence);
	vector<size_t> ret;
	ret.reserve(toks.size());  // one id per token; avoid reallocations
	for (size_t i = 0; i < toks.size(); ++i) {
		ret.push_back(GetOrCreate(toks[i]));
	}
	return ret;
}

17
src/vocab.h Normal file
View File

@ -0,0 +1,17 @@
#pragma once
#include <unordered_map>
#include <string>
#include <vector>
// Mutable word-to-id vocabulary. Ids are assigned densely in first-seen
// order (see GetOrCreate in vocab.cpp).
class Vocab
{
public:
	// Return the id of `word`, inserting it with a fresh id if unknown.
	size_t GetOrCreate(const std::string &word);

	// Tokenize `sentence` and return the id of each token, in order.
	std::vector<size_t> ProcessSentence(const std::string &sentence);

protected:
	// Word -> id lookup table; size doubles as the next free id.
	using Coll = std::unordered_map<std::string, size_t>;
	Coll coll_;
};