mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Merge branch 'master' of github.com:emjotde/marian
This commit is contained in:
commit
448c6f5e77
4
examples/mt/download.sh
Executable file
4
examples/mt/download.sh
Executable file
@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
wget http://data.statmt.org/wmt16/translation-task/dev.tgz
|
||||||
|
tar xvf dev.tgz
|
||||||
|
|
@ -9,6 +9,7 @@ cuda_add_library(marian_lib
|
|||||||
tensor.cu
|
tensor.cu
|
||||||
tensor_operators.cu
|
tensor_operators.cu
|
||||||
expression_operators.cu
|
expression_operators.cu
|
||||||
|
vocab.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_libraries(marian_lib)
|
target_link_libraries(marian_lib)
|
||||||
@ -35,11 +36,16 @@ cuda_add_executable(
|
|||||||
validate_mnist_batch
|
validate_mnist_batch
|
||||||
validate_mnist_batch.cu
|
validate_mnist_batch.cu
|
||||||
)
|
)
|
||||||
|
cuda_add_executable(
|
||||||
|
validate_encoder_decoder
|
||||||
|
validate_encoder_decoder.cu
|
||||||
|
)
|
||||||
|
|
||||||
target_link_libraries(validate_mnist marian_lib)
|
target_link_libraries(validate_mnist marian_lib)
|
||||||
target_link_libraries(validate_mnist_batch marian_lib)
|
target_link_libraries(validate_mnist_batch marian_lib)
|
||||||
|
target_link_libraries(validate_encoder_decoder marian_lib)
|
||||||
|
|
||||||
foreach(exec marian train_mnist validate_mnist validate_mnist_batch )
|
foreach(exec marian train_mnist validate_mnist validate_mnist_batch validate_encoder_decoder)
|
||||||
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
|
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
|
||||||
cuda_add_cublas_to_target(${exec})
|
cuda_add_cublas_to_target(${exec})
|
||||||
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
|
||||||
|
18
src/test.cu
18
src/test.cu
@ -1,13 +1,17 @@
|
|||||||
|
#include <fstream>
|
||||||
#include "marian.h"
|
#include "marian.h"
|
||||||
#include "mnist.h"
|
#include "mnist.h"
|
||||||
|
#include "vocab.h"
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
cudaSetDevice(0);
|
cudaSetDevice(0);
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
using namespace marian;
|
using namespace marian;
|
||||||
using namespace keywords;
|
using namespace keywords;
|
||||||
|
|
||||||
|
Vocab sourceVocab, targetVocab;
|
||||||
|
|
||||||
int input_size = 10;
|
int input_size = 10;
|
||||||
int output_size = 2;
|
int output_size = 2;
|
||||||
int batch_size = 25;
|
int batch_size = 25;
|
||||||
@ -30,6 +34,18 @@ int main(int argc, char** argv) {
|
|||||||
Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh");
|
Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh");
|
||||||
Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0");
|
Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0");
|
||||||
|
|
||||||
|
// read parallel corpus from file
|
||||||
|
std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
|
||||||
|
std::fstream targetFile("../examples/mt/dev/newstest2013.en");
|
||||||
|
|
||||||
|
string sourceLine, targetLine;
|
||||||
|
while (getline(sourceFile, sourceLine)) {
|
||||||
|
getline(targetFile, targetLine);
|
||||||
|
|
||||||
|
std::vector<size_t> sourceIds = sourceVocab.ProcessSentence(sourceLine);
|
||||||
|
std::vector<size_t> targetIds = sourceVocab.ProcessSentence(targetLine);
|
||||||
|
}
|
||||||
|
|
||||||
std::cerr << "Building RNN..." << std::endl;
|
std::cerr << "Building RNN..." << std::endl;
|
||||||
H.emplace_back(tanh(dot(X[0], Wxh) + dot(h0, Whh) + bh));
|
H.emplace_back(tanh(dot(X[0], Wxh) + dot(h0, Whh) + bh));
|
||||||
for (int t = 1; t < num_inputs; ++t) {
|
for (int t = 1; t < num_inputs; ++t) {
|
||||||
|
53
src/vocab.cpp
Normal file
53
src/vocab.cpp
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
#include "vocab.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
inline std::vector<std::string> Tokenize(const std::string& str,
|
||||||
|
const std::string& delimiters = " \t")
|
||||||
|
{
|
||||||
|
std::vector<std::string> tokens;
|
||||||
|
// Skip delimiters at beginning.
|
||||||
|
std::string::size_type lastPos = str.find_first_not_of(delimiters, 0);
|
||||||
|
// Find first "non-delimiter".
|
||||||
|
std::string::size_type pos = str.find_first_of(delimiters, lastPos);
|
||||||
|
|
||||||
|
while (std::string::npos != pos || std::string::npos != lastPos) {
|
||||||
|
// Found a token, add it to the vector.
|
||||||
|
tokens.push_back(str.substr(lastPos, pos - lastPos));
|
||||||
|
// Skip delimiters. Note the "not_of"
|
||||||
|
lastPos = str.find_first_not_of(delimiters, pos);
|
||||||
|
// Find next "non-delimiter"
|
||||||
|
pos = str.find_first_of(delimiters, lastPos);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
size_t Vocab::GetOrCreate(const std::string &word)
|
||||||
|
{
|
||||||
|
size_t id;
|
||||||
|
Coll::const_iterator iter = coll_.find(word);
|
||||||
|
if (iter == coll_.end()) {
|
||||||
|
id = coll_.size();
|
||||||
|
coll_[word] = id;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
id = iter->second;
|
||||||
|
}
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> Vocab::ProcessSentence(const std::string &sentence)
|
||||||
|
{
|
||||||
|
vector<string> toks = Tokenize(sentence);
|
||||||
|
vector<size_t> ret(toks.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < toks.size(); ++i) {
|
||||||
|
size_t id = GetOrCreate(toks[i]);
|
||||||
|
ret[i] = id;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
17
src/vocab.h
Normal file
17
src/vocab.h
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
class Vocab
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
size_t GetOrCreate(const std::string &word);
|
||||||
|
std::vector<size_t> ProcessSentence(const std::string &sentence);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
typedef std::unordered_map<std::string, size_t> Coll;
|
||||||
|
Coll coll_;
|
||||||
|
};
|
||||||
|
|
Loading…
Reference in New Issue
Block a user