mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-01 05:50:03 +03:00
split Sentences class into its own file
This commit is contained in:
parent
b7c1e8b421
commit
1fee16f7df
@ -33,6 +33,7 @@ add_library(libcommon OBJECT
|
|||||||
common/scorer.cpp
|
common/scorer.cpp
|
||||||
common/search.cpp
|
common/search.cpp
|
||||||
common/sentence.cpp
|
common/sentence.cpp
|
||||||
|
common/sentences.cpp
|
||||||
common/types.cpp
|
common/types.cpp
|
||||||
common/utils.cpp
|
common/utils.cpp
|
||||||
common/vocab.cpp
|
common/vocab.cpp
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include "common/threadpool.h"
|
#include "common/threadpool.h"
|
||||||
#include "common/printer.h"
|
#include "common/printer.h"
|
||||||
#include "common/sentence.h"
|
#include "common/sentence.h"
|
||||||
|
#include "common/sentences.h"
|
||||||
#include "common/exception.h"
|
#include "common/exception.h"
|
||||||
#include "common/translation_task.h"
|
#include "common/translation_task.h"
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#include "history.h"
|
#include "history.h"
|
||||||
#include "sentence.h"
|
#include "sentence.h"
|
||||||
|
#include "sentences.h"
|
||||||
|
|
||||||
namespace amunmt {
|
namespace amunmt {
|
||||||
|
|
||||||
|
@ -7,6 +7,8 @@
|
|||||||
|
|
||||||
namespace amunmt {
|
namespace amunmt {
|
||||||
|
|
||||||
|
class Sentences;
|
||||||
|
|
||||||
class History {
|
class History {
|
||||||
private:
|
private:
|
||||||
struct HypothesisCoord {
|
struct HypothesisCoord {
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
namespace amunmt {
|
namespace amunmt {
|
||||||
|
|
||||||
class God;
|
class God;
|
||||||
|
class Sentences;
|
||||||
|
|
||||||
class State {
|
class State {
|
||||||
public:
|
public:
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
#include "common/search.h"
|
|
||||||
|
|
||||||
#include <boost/timer/timer.hpp>
|
#include <boost/timer/timer.hpp>
|
||||||
|
#include "common/search.h"
|
||||||
|
#include "common/sentences.h"
|
||||||
#include "common/god.h"
|
#include "common/god.h"
|
||||||
#include "common/history.h"
|
#include "common/history.h"
|
||||||
#include "common/filter.h"
|
#include "common/filter.h"
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
#include <algorithm>
|
|
||||||
#include "sentence.h"
|
#include "sentence.h"
|
||||||
#include "god.h"
|
#include "god.h"
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
@ -51,47 +50,6 @@ const Words& Sentence::GetWords(size_t index) const {
|
|||||||
return words_[index];
|
return words_[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////
|
|
||||||
Sentences::Sentences()
|
|
||||||
: maxLength_(0)
|
|
||||||
{}
|
|
||||||
|
|
||||||
Sentences::~Sentences()
|
|
||||||
{}
|
|
||||||
|
|
||||||
void Sentences::push_back(SentencePtr sentence) {
|
|
||||||
const Words &words = sentence->GetWords(0);
|
|
||||||
size_t len = words.size();
|
|
||||||
if (len > maxLength_) {
|
|
||||||
maxLength_ = len;
|
|
||||||
}
|
|
||||||
|
|
||||||
coll_.push_back(sentence);
|
|
||||||
}
|
|
||||||
|
|
||||||
class LengthOrderer {
|
|
||||||
public:
|
|
||||||
bool operator()(const SentencePtr& a, const SentencePtr& b) const {
|
|
||||||
return a->GetWords(0).size() < b->GetWords(0).size();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void Sentences::SortByLength() {
|
|
||||||
std::sort(coll_.rbegin(), coll_.rend(), LengthOrderer());
|
|
||||||
}
|
|
||||||
|
|
||||||
SentencesPtr Sentences::NextMiniBatch(size_t batchsize)
|
|
||||||
{
|
|
||||||
SentencesPtr sentences(new Sentences());
|
|
||||||
size_t startInd = (batchsize > size()) ? 0 : size() - batchsize;
|
|
||||||
for (size_t i = startInd; i < size(); ++i) {
|
|
||||||
SentencePtr sentence = at(i);
|
|
||||||
sentences->push_back(sentence);
|
|
||||||
}
|
|
||||||
|
|
||||||
coll_.resize(startInd);
|
|
||||||
return sentences;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,39 +27,6 @@ class Sentence {
|
|||||||
|
|
||||||
using SentencePtr = std::shared_ptr<Sentence>;
|
using SentencePtr = std::shared_ptr<Sentence>;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////
|
|
||||||
class Sentences;
|
|
||||||
using SentencesPtr = std::shared_ptr<Sentences>;
|
|
||||||
|
|
||||||
class Sentences {
|
|
||||||
public:
|
|
||||||
Sentences();
|
|
||||||
~Sentences();
|
|
||||||
|
|
||||||
void push_back(SentencePtr sentence);
|
|
||||||
|
|
||||||
SentencePtr at(size_t id) const {
|
|
||||||
return coll_.at(id);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t size() const {
|
|
||||||
return coll_.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t GetMaxLength() const {
|
|
||||||
return maxLength_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SortByLength();
|
|
||||||
|
|
||||||
SentencesPtr NextMiniBatch(size_t batchsize);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
std::vector<SentencePtr> coll_;
|
|
||||||
size_t maxLength_;
|
|
||||||
|
|
||||||
Sentences(const Sentences &) = delete;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
48
src/common/sentences.cpp
Normal file
48
src/common/sentences.cpp
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
#include <algorithm>
|
||||||
|
#include "sentences.h"
|
||||||
|
|
||||||
|
namespace amunmt {
|
||||||
|
|
||||||
|
Sentences::Sentences()
|
||||||
|
: maxLength_(0)
|
||||||
|
{}
|
||||||
|
|
||||||
|
Sentences::~Sentences()
|
||||||
|
{}
|
||||||
|
|
||||||
|
void Sentences::push_back(SentencePtr sentence) {
|
||||||
|
const Words &words = sentence->GetWords(0);
|
||||||
|
size_t len = words.size();
|
||||||
|
if (len > maxLength_) {
|
||||||
|
maxLength_ = len;
|
||||||
|
}
|
||||||
|
|
||||||
|
coll_.push_back(sentence);
|
||||||
|
}
|
||||||
|
|
||||||
|
class LengthOrderer {
|
||||||
|
public:
|
||||||
|
bool operator()(const SentencePtr& a, const SentencePtr& b) const {
|
||||||
|
return a->GetWords(0).size() < b->GetWords(0).size();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void Sentences::SortByLength() {
|
||||||
|
std::sort(coll_.rbegin(), coll_.rend(), LengthOrderer());
|
||||||
|
}
|
||||||
|
|
||||||
|
SentencesPtr Sentences::NextMiniBatch(size_t batchsize)
|
||||||
|
{
|
||||||
|
SentencesPtr sentences(new Sentences());
|
||||||
|
size_t startInd = (batchsize > size()) ? 0 : size() - batchsize;
|
||||||
|
for (size_t i = startInd; i < size(); ++i) {
|
||||||
|
SentencePtr sentence = at(i);
|
||||||
|
sentences->push_back(sentence);
|
||||||
|
}
|
||||||
|
|
||||||
|
coll_.resize(startInd);
|
||||||
|
return sentences;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
40
src/common/sentences.h
Normal file
40
src/common/sentences.h
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
#pragma once
|
||||||
|
#include "sentence.h"
|
||||||
|
|
||||||
|
namespace amunmt {
|
||||||
|
|
||||||
|
class Sentences;
|
||||||
|
using SentencesPtr = std::shared_ptr<Sentences>;
|
||||||
|
|
||||||
|
class Sentences {
|
||||||
|
public:
|
||||||
|
Sentences();
|
||||||
|
~Sentences();
|
||||||
|
|
||||||
|
void push_back(SentencePtr sentence);
|
||||||
|
|
||||||
|
SentencePtr at(size_t id) const {
|
||||||
|
return coll_.at(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return coll_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t GetMaxLength() const {
|
||||||
|
return maxLength_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SortByLength();
|
||||||
|
|
||||||
|
SentencesPtr NextMiniBatch(size_t batchsize);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<SentencePtr> coll_;
|
||||||
|
size_t maxLength_;
|
||||||
|
|
||||||
|
Sentences(const Sentences &) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -12,6 +12,7 @@
|
|||||||
#include "common/loader.h"
|
#include "common/loader.h"
|
||||||
#include "common/scorer.h"
|
#include "common/scorer.h"
|
||||||
#include "common/sentence.h"
|
#include "common/sentence.h"
|
||||||
|
#include "common/sentences.h"
|
||||||
|
|
||||||
#include "cpu/mblas/matrix.h"
|
#include "cpu/mblas/matrix.h"
|
||||||
#include "cpu/decoder/best_hyps.h"
|
#include "cpu/decoder/best_hyps.h"
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include "../mblas/matrix.h"
|
#include "../mblas/matrix.h"
|
||||||
#include "../dl4mt/model.h"
|
#include "../dl4mt/model.h"
|
||||||
#include "../dl4mt/gru.h"
|
#include "../dl4mt/gru.h"
|
||||||
|
|
||||||
namespace amunmt {
|
namespace amunmt {
|
||||||
namespace CPU {
|
namespace CPU {
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include "common/god.h"
|
#include "common/god.h"
|
||||||
|
#include "common/sentences.h"
|
||||||
|
|
||||||
#include "encoder_decoder.h"
|
#include "encoder_decoder.h"
|
||||||
#include "gpu/mblas/matrix_functions.h"
|
#include "gpu/mblas/matrix_functions.h"
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#include "encoder.h"
|
#include "encoder.h"
|
||||||
|
#include "common/sentences.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
@ -7,6 +7,9 @@
|
|||||||
#include "gpu/types-gpu.h"
|
#include "gpu/types-gpu.h"
|
||||||
|
|
||||||
namespace amunmt {
|
namespace amunmt {
|
||||||
|
|
||||||
|
class Sentences;
|
||||||
|
|
||||||
namespace GPU {
|
namespace GPU {
|
||||||
|
|
||||||
class Encoder {
|
class Encoder {
|
||||||
|
Loading…
Reference in New Issue
Block a user