APE penalty

This commit is contained in:
Marcin Junczys-Dowmunt 2016-04-27 12:47:47 +02:00
parent c2059610f5
commit 2bf3026097
3 changed files with 72 additions and 2 deletions

View File

@ -2,8 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
project(amunn CXX)
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O3 -funroll-loops -Wno-unused-result -Wno-deprecated")
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O3; -arch=sm_35; -lineinfo; --use_fast_math;)
SET(CMAKE_CXX_FLAGS " -std=c++11 -g -O0 -funroll-loops -Wno-unused-result -Wno-deprecated")
LIST(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -std=c++11; -g; -O0; -arch=sm_35; -lineinfo; --use_fast_math;)
add_definitions(-DCUDA_API_PER_THREAD_DEFAULT_STREAM)
SET(CUDA_PROPAGATE_HOST_FLAGS OFF)

59
src/decoder/ape_penalty.h Normal file
View File

@ -0,0 +1,59 @@
#pragma once
#include <vector>
#include "types.h"
#include "scorer.h"
#include "matrix.h"
class ApePenaltyState : public State {
// Dummy
};
class ApePenalty : public Scorer {
public:
ApePenalty(size_t sourceIndex)
: Scorer(sourceIndex)
{ }
virtual void SetSource(const Sentence& source) {
const Words& words = source.GetWords(sourceIndex_);
const Vocab& svcb = God::GetSourceVocab(sourceIndex_);
const Vocab& tvcb = God::GetTargetVocab();
costs_.clear();
costs_.resize(tvcb.size(), -1.0);
for(auto& s : words) {
const std::string& sstr = svcb[s];
Word t = tvcb[sstr];
if(t != UNK && t < costs_.size())
costs_[t] = 0.0;
}
}
virtual void Score(const State& in,
Prob& prob,
State& out) {
size_t cols = prob.Cols();
for(size_t i = 0; i < prob.Rows(); ++i)
algo::copy(costs_.begin(), costs_.begin() + cols, prob.begin() + i * cols);
}
virtual State* NewState() {
return new ApePenaltyState();
}
virtual void BeginSentenceState(State& state) { }
virtual void AssembleBeamState(const State& in,
const Beam& beam,
State& out) { }
virtual size_t GetVocabSize() const {
return 0;
}
private:
std::vector<float> costs_;
};

View File

@ -6,6 +6,7 @@
#include "threadpool.h"
#include "encoder_decoder.h"
#include "language_model.h"
#include "ape_penalty.h"
God God::instance_;
@ -45,6 +46,8 @@ God& God::NonStaticInit(int argc, char** argv) {
"Path to source vocabulary file.")
("target,t", po::value(&targetVocabPath)->required(),
"Path to target vocabulary file.")
("ape", po::value<bool>()->zero_tokens()->default_value(false),
"Add APE-penalty")
("lm,l", po::value(&lmPaths)->multitoken(),
"Path to KenLM language model(s)")
("tab-map", po::value(&tabMap_)->multitoken()->default_value(std::vector<size_t>(1, 0), "0"),
@ -124,11 +127,17 @@ God& God::NonStaticInit(int argc, char** argv) {
tabMap_.resize(modelPaths.size(), 0);
}
// @TODO: handle this better!
if(weights_.size() < modelPaths.size()) {
// this should be a warning
LOG(info) << "More neural models than weights, setting weights to 1.0";
weights_.resize(modelPaths.size(), 1.0);
}
if(Get<bool>("ape") && weights_.size() < modelPaths.size() + 1) {
LOG(info) << "Adding weight for APE-penalty: " << 1.0;
weights_.resize(modelPaths.size(), 1.0);
}
if(weights_.size() < modelPaths.size() + lmPaths.size()) {
// this should be a warning
@ -186,6 +195,8 @@ std::vector<ScorerPtr> God::GetScorers(size_t threadId) {
size_t i = 0;
for(auto& m : Summon().modelsPerDevice_[deviceId])
scorers.emplace_back(new EncoderDecoder(*m, Summon().tabMap_[i++]));
if(God::Get<bool>("ape"))
scorers.emplace_back(new ApePenalty(Summon().tabMap_[i++]));
for(auto& lm : Summon().lms_)
scorers.emplace_back(new LanguageModel(lm));
return scorers;