mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
APE penalty
This commit is contained in:
parent
c2059610f5
commit
2bf3026097
@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
|
|||||||
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
||||||
|
|
||||||
project(amunn CXX)
|
project(amunn CXX)
|
||||||
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
||||||
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math;)
|
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math;)
|
||||||
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
|
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
|
||||||
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
||||||
|
|
||||||
|
59
src/decoder/ape_penalty.h
Normal file
59
src/decoder/ape_penalty.h
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
#include "scorer.h"
|
||||||
|
#include "matrix.h"
|
||||||
|
|
||||||
|
// Placeholder decoder state for the APE penalty scorer.
// The penalty is stateless (it depends only on the current source
// sentence, see ApePenalty::SetSource), so this carries no data —
// it exists only because the Scorer interface requires a State object.
class ApePenaltyState : public State {
  // Dummy
};
|
||||||
|
|
||||||
|
// Scorer that adds a per-token "APE penalty" to the ensemble:
// target words that also appear (verbatim) in the source sentence get
// cost 0.0, all other target words get cost -1.0. This biases an
// automatic post-editing (APE) system towards copying source tokens.
class ApePenalty : public Scorer {

  public:
    // sourceIndex selects which input tab/factor of the (possibly
    // multi-source) sentence this penalty reads; forwarded to Scorer.
    ApePenalty(size_t sourceIndex)
    : Scorer(sourceIndex)
    { }

    // Rebuild the per-target-word cost table for a new source sentence.
    // costs_ is sized to the full target vocabulary, initialized to -1.0
    // (penalty), then every source token that maps to a known target
    // word has its cost reset to 0.0 (no penalty).
    virtual void SetSource(const Sentence& source) {
      const Words& words = source.GetWords(sourceIndex_);
      const Vocab& svcb = God::GetSourceVocab(sourceIndex_);
      const Vocab& tvcb = God::GetTargetVocab();

      costs_.clear();
      costs_.resize(tvcb.size(), -1.0);
      for(auto& s : words) {
        const std::string& sstr = svcb[s];
        // Look up the surface form of the source token in the target
        // vocabulary; UNK means the token has no target-side entry.
        Word t = tvcb[sstr];
        if(t != UNK && t < costs_.size())
          costs_[t] = 0.0;
      }
    }

    // Write the precomputed costs into every row (hypothesis) of prob.
    // NOTE(review): overwrites prob rather than adding to it — presumably
    // the ensemble combines scorers via weights elsewhere; verify.
    // Assumes costs_.size() >= prob.Cols(), i.e. SetSource was called
    // and the prob matrix is sized by the target vocabulary — TODO confirm.
    virtual void Score(const State& in,
                       Prob& prob,
                       State& out) {
      size_t cols = prob.Cols();
      for(size_t i = 0; i < prob.Rows(); ++i)
        algo::copy(costs_.begin(), costs_.begin() + cols, prob.begin() + i * cols);
    }

    // The penalty is stateless; return the dummy state.
    virtual State* NewState() {
      return new ApePenaltyState();
    }

    // No per-sentence decoder state to initialize.
    virtual void BeginSentenceState(State& state) { }

    // No state to carry across beam expansion.
    virtual void AssembleBeamState(const State& in,
                                   const Beam& beam,
                                   State& out) { }

    // 0 signals "no own vocabulary" — presumably the caller treats this
    // scorer as vocabulary-agnostic; verify against Scorer's contract.
    virtual size_t GetVocabSize() const {
      return 0;
    }

  private:
    // Per-target-word costs: 0.0 if the word occurs in the source
    // sentence, -1.0 otherwise. Rebuilt by SetSource for each sentence.
    std::vector<float> costs_;
};
|
@ -6,6 +6,7 @@
|
|||||||
#include "threadpool.h"
|
#include "threadpool.h"
|
||||||
#include "encoder_decoder.h"
|
#include "encoder_decoder.h"
|
||||||
#include "language_model.h"
|
#include "language_model.h"
|
||||||
|
#include "ape_penalty.h"
|
||||||
|
|
||||||
God God::instance_;
|
God God::instance_;
|
||||||
|
|
||||||
@ -45,6 +46,8 @@ God& God::NonStaticInit(int argc, char** argv) {
|
|||||||
"Path to source vocabulary file.")
|
"Path to source vocabulary file.")
|
||||||
("target,t", po::value(&targetVocabPath)->required(),
|
("target,t", po::value(&targetVocabPath)->required(),
|
||||||
"Path to target vocabulary file.")
|
"Path to target vocabulary file.")
|
||||||
|
("ape", po::value<bool>()->zero_tokens()->default_value(false),
|
||||||
|
"Add APE-penalty")
|
||||||
("lm,l", po::value(&lmPaths)->multitoken(),
|
("lm,l", po::value(&lmPaths)->multitoken(),
|
||||||
"Path to KenLM language model(s)")
|
"Path to KenLM language model(s)")
|
||||||
("tab-map", po::value(&tabMap_)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
("tab-map", po::value(&tabMap_)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
||||||
@ -124,11 +127,17 @@ God& God::NonStaticInit(int argc, char** argv) {
|
|||||||
tabMap_.resize(modelPaths.size(), 0);
|
tabMap_.resize(modelPaths.size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// @TODO: handle this better!
|
||||||
if(weights_.size() < modelPaths.size()) {
|
if(weights_.size() < modelPaths.size()) {
|
||||||
// this should be a warning
|
// this should be a warning
|
||||||
LOG(info) << "More neural models than weights, setting weights to 1.0";
|
LOG(info) << "More neural models than weights, setting weights to 1.0";
|
||||||
weights_.resize(modelPaths.size(), 1.0);
|
weights_.resize(modelPaths.size(), 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When the APE penalty is enabled it joins the ensemble as one extra
// scorer beyond the neural models, so it needs one extra weight.
// BUG FIX: resize target must be modelPaths.size() + 1. The preceding
// block already guarantees weights_.size() >= modelPaths.size(), so
// resizing to modelPaths.size() was a no-op and the APE weight was
// never actually appended despite the log message.
if(Get<bool>("ape") && weights_.size() < modelPaths.size() + 1) {
  LOG(info) << "Adding weight for APE-penalty: " << 1.0;
  weights_.resize(modelPaths.size() + 1, 1.0);
}
|
||||||
|
|
||||||
if(weights_.size() < modelPaths.size() + lmPaths.size()) {
|
if(weights_.size() < modelPaths.size() + lmPaths.size()) {
|
||||||
// this should be a warning
|
// this should be a warning
|
||||||
@ -186,6 +195,8 @@ std::vector<ScorerPtr> God::GetScorers(size_t threadId) {
|
|||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for(auto& m : Summon().modelsPerDevice_[deviceId])
|
for(auto& m : Summon().modelsPerDevice_[deviceId])
|
||||||
scorers.emplace_back(new EncoderDecoder(*m, Summon().tabMap_[i++]));
|
scorers.emplace_back(new EncoderDecoder(*m, Summon().tabMap_[i++]));
|
||||||
|
if(God::Get<bool>("ape"))
|
||||||
|
scorers.emplace_back(new ApePenalty(Summon().tabMap_[i++]));
|
||||||
for(auto& lm : Summon().lms_)
|
for(auto& lm : Summon().lms_)
|
||||||
scorers.emplace_back(new LanguageModel(lm));
|
scorers.emplace_back(new LanguageModel(lm));
|
||||||
return scorers;
|
return scorers;
|
||||||
|
Loading…
Reference in New Issue
Block a user