split Sentences class into its own file

This commit is contained in:
Hieu Hoang 2017-02-17 16:48:27 +00:00
parent b7c1e8b421
commit 1fee16f7df
15 changed files with 103 additions and 79 deletions

View File

@ -33,6 +33,7 @@ add_library(libcommon OBJECT
common/scorer.cpp common/scorer.cpp
common/search.cpp common/search.cpp
common/sentence.cpp common/sentence.cpp
common/sentences.cpp
common/types.cpp common/types.cpp
common/utils.cpp common/utils.cpp
common/vocab.cpp common/vocab.cpp

View File

@ -10,6 +10,7 @@
#include "common/threadpool.h" #include "common/threadpool.h"
#include "common/printer.h" #include "common/printer.h"
#include "common/sentence.h" #include "common/sentence.h"
#include "common/sentences.h"
#include "common/exception.h" #include "common/exception.h"
#include "common/translation_task.h" #include "common/translation_task.h"

View File

@ -1,5 +1,6 @@
#include "history.h" #include "history.h"
#include "sentence.h" #include "sentence.h"
#include "sentences.h"
namespace amunmt { namespace amunmt {

View File

@ -7,6 +7,8 @@
namespace amunmt { namespace amunmt {
class Sentences;
class History { class History {
private: private:
struct HypothesisCoord { struct HypothesisCoord {

View File

@ -11,6 +11,7 @@
namespace amunmt { namespace amunmt {
class God; class God;
class Sentences;
class State { class State {
public: public:

View File

@ -1,7 +1,6 @@
#include "common/search.h"
#include <boost/timer/timer.hpp> #include <boost/timer/timer.hpp>
#include "common/search.h"
#include "common/sentences.h"
#include "common/god.h" #include "common/god.h"
#include "common/history.h" #include "common/history.h"
#include "common/filter.h" #include "common/filter.h"

View File

@ -1,4 +1,3 @@
#include <algorithm>
#include "sentence.h" #include "sentence.h"
#include "god.h" #include "god.h"
#include "utils.h" #include "utils.h"
@ -51,47 +50,6 @@ const Words& Sentence::GetWords(size_t index) const {
return words_[index]; return words_[index];
} }
/////////////////////////////////////////////////////////
// Creates an empty collection; the recorded maximum length starts at 0.
Sentences::Sentences()
: maxLength_(0)
{}
// Destructor with an empty body; members are destroyed implicitly.
Sentences::~Sentences()
{}
// Appends `sentence` to coll_ and raises maxLength_ when the sequence returned
// by GetWords(0) is longer than the current maximum.
// `sentence` is dereferenced unconditionally, so it must not be null.
void Sentences::push_back(SentencePtr sentence) {
const Words &words = sentence->GetWords(0);
size_t len = words.size();
if (len > maxLength_) {
maxLength_ = len;
}
coll_.push_back(sentence);
}
// Comparison functor: true when a's GetWords(0) sequence is strictly shorter
// than b's. Both pointers are dereferenced without a null check.
class LengthOrderer {
public:
bool operator()(const SentencePtr& a, const SentencePtr& b) const {
return a->GetWords(0).size() < b->GetWords(0).size();
}
};
// Sorts coll_ with LengthOrderer applied over reverse iterators, so that in
// forward order the sentences run from longest to shortest GetWords(0) size.
void Sentences::SortByLength() {
std::sort(coll_.rbegin(), coll_.rend(), LengthOrderer());
}
// Transfers the last `batchsize` sentences (all of them when fewer are held)
// into a newly allocated Sentences, keeping their relative order, and then
// drops them from this collection.
// maxLength_ of this collection is not adjusted after the removal.
SentencesPtr Sentences::NextMiniBatch(size_t batchsize)
{
SentencesPtr sentences(new Sentences());
// Index of the first sentence that goes into the batch.
size_t startInd = (batchsize > size()) ? 0 : size() - batchsize;
for (size_t i = startInd; i < size(); ++i) {
SentencePtr sentence = at(i);
sentences->push_back(sentence);
}
// Keep only the sentences that were not handed out.
coll_.resize(startInd);
return sentences;
}
} }

View File

@ -27,39 +27,6 @@ class Sentence {
using SentencePtr = std::shared_ptr<Sentence>; using SentencePtr = std::shared_ptr<Sentence>;
//////////////////////////////////////////////////////////////////
class Sentences;
// Shared-ownership handle to a Sentences collection.
using SentencesPtr = std::shared_ptr<Sentences>;
// Collection of SentencePtr that also records the largest GetWords(0) size
// among the sentences added through push_back.
class Sentences {
public:
Sentences();
~Sentences();
// Appends a sentence and updates the recorded maximum length.
void push_back(SentencePtr sentence);
// Bounds-checked element access (delegates to std::vector::at).
SentencePtr at(size_t id) const {
return coll_.at(id);
}
// Number of sentences currently held.
size_t size() const {
return coll_.size();
}
// Recorded maximum length; see push_back.
size_t GetMaxLength() const {
return maxLength_;
}
void SortByLength();
// Splits off up to `batchsize` sentences from the end into a new collection.
SentencesPtr NextMiniBatch(size_t batchsize);
protected:
std::vector<SentencePtr> coll_;
size_t maxLength_;
// Copy construction is disabled.
// NOTE(review): copy assignment is not deleted here — confirm whether that is intended.
Sentences(const Sentences &) = delete;
};
} }

48
src/common/sentences.cpp Normal file
View File

@ -0,0 +1,48 @@
#include <algorithm>
#include "sentences.h"
namespace amunmt {
// Builds an empty collection whose recorded maximum length is zero.
Sentences::Sentences() : maxLength_(0) {}
// Nothing to release by hand: the members clean up after themselves.
Sentences::~Sentences() = default;
// Adds `sentence` to the end of the collection. Before storing it, the size of
// its GetWords(0) sequence is folded into the running maximum length.
// `sentence` must not be null; it is dereferenced without a check.
void Sentences::push_back(SentencePtr sentence) {
  const size_t wordCount = sentence->GetWords(0).size();
  maxLength_ = std::max(maxLength_, wordCount);
  coll_.push_back(sentence);
}
namespace {

// Comparison functor used by Sentences::SortByLength: returns true when the
// GetWords(0) sequence of `a` is strictly shorter than that of `b`.
// Neither pointer is checked for null.
//
// Placed in an unnamed namespace because this helper is not declared in
// sentences.h; internal linkage keeps it from clashing with a same-named
// class defined in another translation unit.
class LengthOrderer {
 public:
  bool operator()(const SentencePtr& a, const SentencePtr& b) const {
    return a->GetWords(0).size() < b->GetWords(0).size();
  }
};

}  // namespace
void Sentences::SortByLength() {
std::sort(coll_.rbegin(), coll_.rend(), LengthOrderer());
}
// Splits off the tail of the collection as a new mini-batch.
//
// The last `batchsize` sentences (or every sentence, when fewer are held) are
// added to a freshly allocated Sentences in their current order and are then
// removed from this collection. The returned batch tracks its own maximum
// length through push_back; maxLength_ of this collection is left as it was.
SentencesPtr Sentences::NextMiniBatch(size_t batchsize)
{
  const size_t total = size();
  // Position of the first sentence that belongs to the batch.
  const size_t first = total - std::min(batchsize, total);

  SentencesPtr batch(new Sentences());
  for (size_t pos = first; pos != total; ++pos) {
    batch->push_back(at(pos));
  }

  // Drop everything that was handed over to the batch.
  coll_.resize(first);
  return batch;
}
}

40
src/common/sentences.h Normal file
View File

@ -0,0 +1,40 @@
#pragma once
#include "sentence.h"
namespace amunmt {
class Sentences;
// Shared-ownership handle to a Sentences collection.
using SentencesPtr = std::shared_ptr<Sentences>;

// Ordered collection of SentencePtr that also records the largest
// GetWords(0) size among the sentences added through push_back.
class Sentences {
 public:
  Sentences();
  ~Sentences();

  // Appends `sentence` and updates the recorded maximum length.
  void push_back(SentencePtr sentence);

  // Bounds-checked access; std::vector::at throws std::out_of_range when
  // `id` is not a valid index.
  SentencePtr at(size_t id) const {
    return coll_.at(id);
  }

  // Number of sentences currently held.
  size_t size() const {
    return coll_.size();
  }

  // Maximum length recorded by push_back so far.
  size_t GetMaxLength() const {
    return maxLength_;
  }

  // Reorders the sentences by the size of their GetWords(0) sequence.
  void SortByLength();

  // Removes up to `batchsize` sentences from the end of this collection and
  // returns them as a new collection.
  SentencesPtr NextMiniBatch(size_t batchsize);

 protected:
  std::vector<SentencePtr> coll_;
  size_t maxLength_;

  // Not copyable. Copy assignment is deleted together with the copy
  // constructor; otherwise the compiler would still generate it and `a = b`
  // would silently duplicate the pointer vector.
  Sentences(const Sentences &) = delete;
  Sentences& operator=(const Sentences &) = delete;
};
}

View File

@ -12,6 +12,7 @@
#include "common/loader.h" #include "common/loader.h"
#include "common/scorer.h" #include "common/scorer.h"
#include "common/sentence.h" #include "common/sentence.h"
#include "common/sentences.h"
#include "cpu/mblas/matrix.h" #include "cpu/mblas/matrix.h"
#include "cpu/decoder/best_hyps.h" #include "cpu/decoder/best_hyps.h"

View File

@ -3,7 +3,7 @@
#include "../mblas/matrix.h" #include "../mblas/matrix.h"
#include "../dl4mt/model.h" #include "../dl4mt/model.h"
#include "../dl4mt/gru.h" #include "../dl4mt/gru.h"
namespace amunmt { namespace amunmt {
namespace CPU { namespace CPU {

View File

@ -1,6 +1,7 @@
#include <iostream> #include <iostream>
#include "common/god.h" #include "common/god.h"
#include "common/sentences.h"
#include "encoder_decoder.h" #include "encoder_decoder.h"
#include "gpu/mblas/matrix_functions.h" #include "gpu/mblas/matrix_functions.h"

View File

@ -1,4 +1,5 @@
#include "encoder.h" #include "encoder.h"
#include "common/sentences.h"
using namespace std; using namespace std;

View File

@ -7,6 +7,9 @@
#include "gpu/types-gpu.h" #include "gpu/types-gpu.h"
namespace amunmt { namespace amunmt {
class Sentences;
namespace GPU { namespace GPU {
class Encoder { class Encoder {