mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
split Sentences class into its own file
This commit is contained in:
parent
b7c1e8b421
commit
1fee16f7df
@ -33,6 +33,7 @@ add_library(libcommon OBJECT
|
||||
common/scorer.cpp
|
||||
common/search.cpp
|
||||
common/sentence.cpp
|
||||
common/sentences.cpp
|
||||
common/types.cpp
|
||||
common/utils.cpp
|
||||
common/vocab.cpp
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "common/threadpool.h"
|
||||
#include "common/printer.h"
|
||||
#include "common/sentence.h"
|
||||
#include "common/sentences.h"
|
||||
#include "common/exception.h"
|
||||
#include "common/translation_task.h"
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include "history.h"
|
||||
#include "sentence.h"
|
||||
#include "sentences.h"
|
||||
|
||||
namespace amunmt {
|
||||
|
||||
|
@ -7,6 +7,8 @@
|
||||
|
||||
namespace amunmt {
|
||||
|
||||
class Sentences;
|
||||
|
||||
class History {
|
||||
private:
|
||||
struct HypothesisCoord {
|
||||
|
@ -11,6 +11,7 @@
|
||||
namespace amunmt {
|
||||
|
||||
class God;
|
||||
class Sentences;
|
||||
|
||||
class State {
|
||||
public:
|
||||
|
@ -1,7 +1,6 @@
|
||||
#include "common/search.h"
|
||||
|
||||
#include <boost/timer/timer.hpp>
|
||||
|
||||
#include "common/search.h"
|
||||
#include "common/sentences.h"
|
||||
#include "common/god.h"
|
||||
#include "common/history.h"
|
||||
#include "common/filter.h"
|
||||
|
@ -1,4 +1,3 @@
|
||||
#include <algorithm>
|
||||
#include "sentence.h"
|
||||
#include "god.h"
|
||||
#include "utils.h"
|
||||
@ -51,47 +50,6 @@ const Words& Sentence::GetWords(size_t index) const {
|
||||
return words_[index];
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////
|
||||
// Creates an empty collection; with no sentence added yet the recorded
// maximum length is 0.
Sentences::Sentences() : maxLength_(0) {}
|
||||
|
||||
// Nothing to release by hand: the members clean up after themselves.
Sentences::~Sentences() = default;
|
||||
|
||||
// Appends `sentence` to the collection and raises the recorded maximum
// length if this sentence is longer than every one added before it.
// Length is measured on the word sequence at index 0.
void Sentences::push_back(SentencePtr sentence) {
  const size_t numWords = sentence->GetWords(0).size();
  maxLength_ = std::max(maxLength_, numWords);

  coll_.push_back(sentence);
}
|
||||
|
||||
class LengthOrderer {
|
||||
public:
|
||||
bool operator()(const SentencePtr& a, const SentencePtr& b) const {
|
||||
return a->GetWords(0).size() < b->GetWords(0).size();
|
||||
}
|
||||
};
|
||||
|
||||
// Sorts the collection by sentence length. The ascending comparator is
// applied to the reversed range, so read front to back coll_ ends up in
// descending order (longest sentence first, shortest last).
void Sentences::SortByLength() {
  auto first = coll_.rbegin();
  auto last = coll_.rend();
  std::sort(first, last, LengthOrderer());
}
|
||||
|
||||
// Removes up to `batchsize` sentences from the tail of this collection and
// returns them, in their current order, as a new Sentences object. When
// fewer than `batchsize` sentences are left, all of them are returned and
// this collection becomes empty.
// NOTE(review): maxLength_ of this collection is not recomputed after the
// sentences are removed.
SentencesPtr Sentences::NextMiniBatch(size_t batchsize)
{
  const size_t total = size();
  const size_t firstTaken = (total < batchsize) ? 0 : total - batchsize;

  SentencesPtr batch(new Sentences());
  for (size_t pos = firstTaken; pos != total; ++pos) {
    batch->push_back(at(pos));
  }

  // Drop the sentences that were just handed out.
  coll_.resize(firstTaken);
  return batch;
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -27,39 +27,6 @@ class Sentence {
|
||||
|
||||
using SentencePtr = std::shared_ptr<Sentence>;
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
class Sentences;
|
||||
using SentencesPtr = std::shared_ptr<Sentences>;
|
||||
|
||||
class Sentences {
|
||||
public:
|
||||
Sentences();
|
||||
~Sentences();
|
||||
|
||||
void push_back(SentencePtr sentence);
|
||||
|
||||
SentencePtr at(size_t id) const {
|
||||
return coll_.at(id);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return coll_.size();
|
||||
}
|
||||
|
||||
size_t GetMaxLength() const {
|
||||
return maxLength_;
|
||||
}
|
||||
|
||||
void SortByLength();
|
||||
|
||||
SentencesPtr NextMiniBatch(size_t batchsize);
|
||||
|
||||
protected:
|
||||
std::vector<SentencePtr> coll_;
|
||||
size_t maxLength_;
|
||||
|
||||
Sentences(const Sentences &) = delete;
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
|
48
src/common/sentences.cpp
Normal file
48
src/common/sentences.cpp
Normal file
@ -0,0 +1,48 @@
|
||||
#include <algorithm>
|
||||
#include "sentences.h"
|
||||
|
||||
namespace amunmt {
|
||||
|
||||
// Creates an empty collection; with no sentence added yet the recorded
// maximum length is 0.
Sentences::Sentences() : maxLength_(0) {}
|
||||
|
||||
// Nothing to release by hand: the members clean up after themselves.
Sentences::~Sentences() = default;
|
||||
|
||||
// Appends `sentence` to the collection and raises the recorded maximum
// length if this sentence is longer than every one added before it.
// Length is measured on the word sequence at index 0.
void Sentences::push_back(SentencePtr sentence) {
  const size_t numWords = sentence->GetWords(0).size();
  maxLength_ = std::max(maxLength_, numWords);

  coll_.push_back(sentence);
}
|
||||
|
||||
class LengthOrderer {
|
||||
public:
|
||||
bool operator()(const SentencePtr& a, const SentencePtr& b) const {
|
||||
return a->GetWords(0).size() < b->GetWords(0).size();
|
||||
}
|
||||
};
|
||||
|
||||
// Sorts the collection by sentence length. The ascending comparator is
// applied to the reversed range, so read front to back coll_ ends up in
// descending order (longest sentence first, shortest last).
void Sentences::SortByLength() {
  auto first = coll_.rbegin();
  auto last = coll_.rend();
  std::sort(first, last, LengthOrderer());
}
|
||||
|
||||
// Removes up to `batchsize` sentences from the tail of this collection and
// returns them, in their current order, as a new Sentences object. When
// fewer than `batchsize` sentences are left, all of them are returned and
// this collection becomes empty.
// NOTE(review): maxLength_ of this collection is not recomputed after the
// sentences are removed.
SentencesPtr Sentences::NextMiniBatch(size_t batchsize)
{
  const size_t total = size();
  const size_t firstTaken = (total < batchsize) ? 0 : total - batchsize;

  SentencesPtr batch(new Sentences());
  for (size_t pos = firstTaken; pos != total; ++pos) {
    batch->push_back(at(pos));
  }

  // Drop the sentences that were just handed out.
  coll_.resize(firstTaken);
  return batch;
}
|
||||
|
||||
}
|
||||
|
40
src/common/sentences.h
Normal file
40
src/common/sentences.h
Normal file
@ -0,0 +1,40 @@
|
||||
#pragma once
|
||||
#include "sentence.h"
|
||||
|
||||
namespace amunmt {
|
||||
|
||||
class Sentences;
|
||||
using SentencesPtr = std::shared_ptr<Sentences>;
|
||||
|
||||
class Sentences {
|
||||
public:
|
||||
Sentences();
|
||||
~Sentences();
|
||||
|
||||
void push_back(SentencePtr sentence);
|
||||
|
||||
SentencePtr at(size_t id) const {
|
||||
return coll_.at(id);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return coll_.size();
|
||||
}
|
||||
|
||||
size_t GetMaxLength() const {
|
||||
return maxLength_;
|
||||
}
|
||||
|
||||
void SortByLength();
|
||||
|
||||
SentencesPtr NextMiniBatch(size_t batchsize);
|
||||
|
||||
protected:
|
||||
std::vector<SentencePtr> coll_;
|
||||
size_t maxLength_;
|
||||
|
||||
Sentences(const Sentences &) = delete;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "common/loader.h"
|
||||
#include "common/scorer.h"
|
||||
#include "common/sentence.h"
|
||||
#include "common/sentences.h"
|
||||
|
||||
#include "cpu/mblas/matrix.h"
|
||||
#include "cpu/decoder/best_hyps.h"
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include "../mblas/matrix.h"
|
||||
#include "../dl4mt/model.h"
|
||||
#include "../dl4mt/gru.h"
|
||||
|
||||
|
||||
namespace amunmt {
|
||||
namespace CPU {
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "common/god.h"
|
||||
#include "common/sentences.h"
|
||||
|
||||
#include "encoder_decoder.h"
|
||||
#include "gpu/mblas/matrix_functions.h"
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include "encoder.h"
|
||||
#include "common/sentences.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -7,6 +7,9 @@
|
||||
#include "gpu/types-gpu.h"
|
||||
|
||||
namespace amunmt {
|
||||
|
||||
class Sentences;
|
||||
|
||||
namespace GPU {
|
||||
|
||||
class Encoder {
|
||||
|
Loading…
Reference in New Issue
Block a user