mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
APE penalty
This commit is contained in:
parent
c2059610f5
commit
2bf3026097
@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
|
||||
|
||||
project(amunn CXX)
|
||||
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
||||
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math;)
|
||||
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated")
|
||||
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math;)
|
||||
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
|
||||
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
||||
|
||||
|
59
src/decoder/ape_penalty.h
Normal file
59
src/decoder/ape_penalty.h
Normal file
@ -0,0 +1,59 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "types.h"
|
||||
#include "scorer.h"
|
||||
#include "matrix.h"
|
||||
|
||||
class ApePenaltyState : public State {
|
||||
// Dummy
|
||||
};
|
||||
|
||||
class ApePenalty : public Scorer {
|
||||
|
||||
public:
|
||||
ApePenalty(size_t sourceIndex)
|
||||
: Scorer(sourceIndex)
|
||||
{ }
|
||||
|
||||
virtual void SetSource(const Sentence& source) {
|
||||
const Words& words = source.GetWords(sourceIndex_);
|
||||
const Vocab& svcb = God::GetSourceVocab(sourceIndex_);
|
||||
const Vocab& tvcb = God::GetTargetVocab();
|
||||
|
||||
costs_.clear();
|
||||
costs_.resize(tvcb.size(), -1.0);
|
||||
for(auto& s : words) {
|
||||
const std::string& sstr = svcb[s];
|
||||
Word t = tvcb[sstr];
|
||||
if(t != UNK && t < costs_.size())
|
||||
costs_[t] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void Score(const State& in,
|
||||
Prob& prob,
|
||||
State& out) {
|
||||
size_t cols = prob.Cols();
|
||||
for(size_t i = 0; i < prob.Rows(); ++i)
|
||||
algo::copy(costs_.begin(), costs_.begin() + cols, prob.begin() + i * cols);
|
||||
}
|
||||
|
||||
virtual State* NewState() {
|
||||
return new ApePenaltyState();
|
||||
}
|
||||
|
||||
virtual void BeginSentenceState(State& state) { }
|
||||
|
||||
virtual void AssembleBeamState(const State& in,
|
||||
const Beam& beam,
|
||||
State& out) { }
|
||||
|
||||
virtual size_t GetVocabSize() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<float> costs_;
|
||||
};
|
@ -6,6 +6,7 @@
|
||||
#include "threadpool.h"
|
||||
#include "encoder_decoder.h"
|
||||
#include "language_model.h"
|
||||
#include "ape_penalty.h"
|
||||
|
||||
God God::instance_;
|
||||
|
||||
@ -45,6 +46,8 @@ God& God::NonStaticInit(int argc, char** argv) {
|
||||
"Path to source vocabulary file.")
|
||||
("target,t", po::value(&targetVocabPath)->required(),
|
||||
"Path to target vocabulary file.")
|
||||
("ape", po::value<bool>()->zero_tokens()->default_value(false),
|
||||
"Add APE-penalty")
|
||||
("lm,l", po::value(&lmPaths)->multitoken(),
|
||||
"Path to KenLM language model(s)")
|
||||
("tab-map", po::value(&tabMap_)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
|
||||
@ -124,11 +127,17 @@ God& God::NonStaticInit(int argc, char** argv) {
|
||||
tabMap_.resize(modelPaths.size(), 0);
|
||||
}
|
||||
|
||||
// @TODO: handle this better!
|
||||
if(weights_.size() < modelPaths.size()) {
|
||||
// this should be a warning
|
||||
LOG(info) << "More neural models than weights, setting weights to 1.0";
|
||||
weights_.resize(modelPaths.size(), 1.0);
|
||||
}
|
||||
|
||||
if(Get<bool>("ape") && weights_.size() < modelPaths.size() + 1) {
|
||||
LOG(info) << "Adding weight for APE-penalty: " << 1.0;
|
||||
weights_.resize(modelPaths.size(), 1.0);
|
||||
}
|
||||
|
||||
if(weights_.size() < modelPaths.size() + lmPaths.size()) {
|
||||
// this should be a warning
|
||||
@ -186,6 +195,8 @@ std::vector<ScorerPtr> God::GetScorers(size_t threadId) {
|
||||
size_t i = 0;
|
||||
for(auto& m : Summon().modelsPerDevice_[deviceId])
|
||||
scorers.emplace_back(new EncoderDecoder(*m, Summon().tabMap_[i++]));
|
||||
if(God::Get<bool>("ape"))
|
||||
scorers.emplace_back(new ApePenalty(Summon().tabMap_[i++]));
|
||||
for(auto& lm : Summon().lms_)
|
||||
scorers.emplace_back(new LanguageModel(lm));
|
||||
return scorers;
|
||||
|
Loading…
Reference in New Issue
Block a user