From e6cb99c48ca1a8eb84a428ad7e3532e6c02de9b6 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Sun, 1 May 2016 11:23:55 +0200 Subject: [PATCH] better file handling, better exceptions hanling --- CMakeLists.txt | 18 +++++------ scripts/idf.py | 9 ++++-- src/CMakeLists.txt | 2 +- src/common/exception.cpp | 10 ++++-- src/common/exception.h | 6 ++-- src/common/file_stream.h | 42 +++++++++++++++++++++++++ src/common/vocab.h | 8 ++--- src/decoder/ape_penalty.h | 59 ++++++++++++++++++++++++++++------- src/decoder/config.cpp | 3 +- src/decoder/encoder_decoder.h | 12 ++++--- src/decoder/god.cu | 3 +- src/decoder/god.h | 3 -- src/decoder/scorer.h | 11 ++++--- 13 files changed, 136 insertions(+), 50 deletions(-) create mode 100644 src/common/file_stream.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 8b939813..9cd42916 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,7 @@ SET(CUDA_PROPAGATE_HOST_FLAGS OFF) include_directories(${amunn_SOURCE_DIR}) find_package(CUDA REQUIRED) -find_package(Boost COMPONENTS system filesystem program_options timer) +find_package(Boost COMPONENTS system filesystem program_options timer iostreams) if(Boost_FOUND) include_directories(${Boost_INCLUDE_DIRS}) set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES}) @@ -24,15 +24,15 @@ if (YAMLCPP_FOUND) set(EXT_LIBS ${EXT_LIBS} ${YAMLCPP_LIBRARY}) endif (YAMLCPP_FOUND) -set(KENLM CACHE STRING "Path to compiled kenlm directory") -if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a") - message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a") -endif() +#set(KENLM CACHE STRING "Path to compiled kenlm directory") +#if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a") +# message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a") +#endif() -set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm.a) -set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm_util.a) -include_directories(${KENLM}) -add_definitions(-DKENLM_MAX_ORDER=6) +#set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm.a) +#set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm_util.a) +#include_directories(${KENLM}) +#add_definitions(-DKENLM_MAX_ORDER=6) find_package (BZip2) if (BZIP2_FOUND) diff --git a/scripts/idf.py b/scripts/idf.py index efd05e9a..60905d80 100644 --- a/scripts/idf.py +++ b/scripts/idf.py @@ -1,5 +1,6 @@ import sys import math +import yaml from collections import Counter c = Counter() @@ -10,7 +11,9 @@ for line in sys.stdin: c[word] += 1 N += 1 -keys = sorted([k for k in c]) -for word in keys: +out = dict() +for word in c: idf = math.log(float(N) / float(c[word])) / math.log(N) - print word, ":", idf + out[word] = idf + +yaml.safe_dump(out, sys.stdout, default_flow_style=False, allow_unicode=True) \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a42b9028..d2deb379 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -17,7 +17,7 @@ add_library(librescorer OBJECT add_library(libamunn OBJECT decoder/config.cpp - decoder/kenlm.cpp +# decoder/kenlm.cpp ) cuda_add_executable( diff --git a/src/common/exception.cpp b/src/common/exception.cpp index 01ff9a67..453fcf66 100644 --- a/src/common/exception.cpp +++ b/src/common/exception.cpp @@ -1,4 +1,4 @@ -#include "util/exception.hh" +#include "exception.h" #ifdef __GXX_RTTI #include @@ -17,14 +17,18 @@ namespace util { Exception::Exception() throw() {} Exception::~Exception() throw() {} +Exception::Exception(const Exception& o) throw() { + what_.str(o.what_.str()); +} + void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) { /* The child class might have set some text, but we want this to come first. * Another option would be passing this information to the constructor, but * then child classes would have to accept constructor arguments and pass * them down. */ - std::string old_text; - what_.swap(old_text); + std::string old_text = what_.str(); + what_.str(std::string()); what_ << file << ':' << line; if (func) what_ << " in " << func << " threw "; if (child_name) { diff --git a/src/common/exception.h b/src/common/exception.h index 0bfd73a6..85827d8c 100644 --- a/src/common/exception.h +++ b/src/common/exception.h @@ -1,7 +1,6 @@ #pragma once -#include "util/string_stream.hh" - +#include #include #include #include @@ -15,6 +14,7 @@ class Exception : public std::exception { public: Exception() throw(); virtual ~Exception() throw(); + Exception(const Exception& o) throw(); const char *what() const throw() { return what_.str().c_str(); } @@ -34,7 +34,7 @@ class Exception : public std::exception { typedef T Identity; }; - StringStream what_; + std::stringstream what_; }; /* This implements the normal operator<< for Exception and all its children. diff --git a/src/common/file_stream.h b/src/common/file_stream.h new file mode 100644 index 00000000..bff84355 --- /dev/null +++ b/src/common/file_stream.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "exception.h" + +class InputFileStream { + public: + InputFileStream(const std::string& file) + : file_(file), ifstream_(file_) + { + UTIL_THROW_IF2(!boost::filesystem::exists(file_), + "File " << file << " does not exist"); + + if(file_.extension() == ".gz") + istream_.push(boost::iostreams::gzip_decompressor()); + istream_.push(ifstream_); + } + + operator std::istream& () { + return istream_; + } + + operator bool () { + return istream_; + } + + template + friend InputFileStream& operator>>(InputFileStream& stream, T& t) { + stream.istream_ >> t; + return stream; + } + + private: + boost::filesystem::path file_; + boost::filesystem::ifstream ifstream_; + boost::iostreams::filtering_istream istream_; +}; diff --git a/src/common/vocab.h b/src/common/vocab.h index 5994f2f9..04b014a0 100644 --- a/src/common/vocab.h +++ b/src/common/vocab.h @@ -3,19 +3,19 @@ #include #include #include -#include #include #include "types.h" #include "utils.h" +#include "file_stream.h" class Vocab { public: - Vocab(const std::string& txt) { - std::ifstream in(txt.c_str()); + Vocab(const std::string& path) { + InputFileStream in(path); size_t c = 0; std::string line; - while(std::getline(in, line)) { + while(std::getline((std::istream&)in, line)) { str2id_[line] = c++; id2str_.push_back(line); } diff --git a/src/decoder/ape_penalty.h b/src/decoder/ape_penalty.h index c64a31c9..8723d1a3 100644 --- a/src/decoder/ape_penalty.h +++ b/src/decoder/ape_penalty.h @@ -3,36 +3,48 @@ #include #include "types.h" +#include "file_stream.h" #include "scorer.h" #include "matrix.h" +typedef std::vector SrcTrgMap; +typedef std::vector Penalties; + class ApePenaltyState : public State { // Dummy, this scorer is stateless }; class ApePenalty : public Scorer { + private: + const SrcTrgMap& srcTrgMap_; + const Penalties& penalties_; public: - ApePenalty(size_t sourceIndex) - : Scorer(sourceIndex) + ApePenalty( + const SrcTrgMap& srcTrgMap, + const Penalties& penalties, + const YAML::Node& config, + size_t tab) + : Scorer(config, tab), srcTrgMap_(srcTrgMap), + penalties_(penalties) { } // @TODO: make this work on GPU virtual void SetSource(const Sentence& source) { - const Words& words = source.GetWords(sourceIndex_); - const Vocab& svcb = God::GetSourceVocab(sourceIndex_); - const Vocab& tvcb = God::GetTargetVocab(); + const Words& words = source.GetWords(tab_); costs_.clear(); - costs_.resize(tvcb.size(), -1.0); - for(auto& s : words) { - const std::string& sstr = svcb[s]; - Word t = tvcb[sstr]; + costs_.resize(penalties_.size()); + algo::copy(penalties_.begin(), penalties_.end(), costs_.begin()); + + for(auto&& s : words) { + Word t = srcTrgMap_[s]; if(t != UNK && t < costs_.size()) costs_[t] = 0.0; } } + // @TODO: make this work on GPU virtual void Score(const State& in, Prob& prob, State& out) { @@ -65,11 +77,34 @@ class ApePenaltyLoader : public Loader { : Loader(config) {} virtual void Load() { - // @TODO: IDF weights + size_t tab = Has("tab") ? Get("tab") : 0; + const Vocab& svcb = God::GetSourceVocab(tab); + const Vocab& tvcb = God::GetTargetVocab(); + + srcTrgMap_.resize(svcb.size(), UNK); + for(Word s = 0; s < svcb.size(); ++s) + srcTrgMap_[s] = tvcb[svcb[s]]; + + penalties_.resize(tvcb.size(), -1.0); + + if(Has("path")) { + LOG(info) << "Loading APE penalties from " << Get("path"); + YAML::Node penalties = YAML::Load(InputFileStream(Get("path"))); + for(auto&& pair : penalties) { + std::string entry = pair.first.as(); + float penalty = pair.second.as(); + penalties_[tvcb[entry]] = -penalty; + } + } } virtual ScorerPtr NewScorer(size_t taskId) { size_t tab = Has("tab") ? Get("tab") : 0; - return ScorerPtr(new ApePenalty(tab)); + return ScorerPtr(new ApePenalty(srcTrgMap_, penalties_, + config_, tab)); } -}; \ No newline at end of file + + private: + SrcTrgMap srcTrgMap_; + Penalties penalties_; +}; diff --git a/src/decoder/config.cpp b/src/decoder/config.cpp index 6fdebe5a..1278dba7 100644 --- a/src/decoder/config.cpp +++ b/src/decoder/config.cpp @@ -1,6 +1,7 @@ #include #include "config.h" +#include "file_stream.h" #include "exception.h" #define SET_OPTION(key, type) \ @@ -79,7 +80,7 @@ void Config::AddOptions(size_t argc, char** argv) { exit(0); } - config_ = YAML::LoadFile(configPath); + config_ = YAML::Load(InputFileStream(configPath)); SET_OPTION("n-best", bool) SET_OPTION("normalize", bool) diff --git a/src/decoder/encoder_decoder.h b/src/decoder/encoder_decoder.h index 92d4985f..1dac4e4d 100644 --- a/src/decoder/encoder_decoder.h +++ b/src/decoder/encoder_decoder.h @@ -38,8 +38,10 @@ class EncoderDecoder : public Scorer { typedef EncoderDecoderState EDState; public: - EncoderDecoder(const Weights& model, size_t tabIndex) - : Scorer(tabIndex), model_(model), + EncoderDecoder(const Weights& model, + const YAML::Node& config, + size_t tab) + : Scorer(config, tab), model_(model), encoder_(new Encoder(model_)), decoder_(new Decoder(model_)) {} @@ -65,7 +67,7 @@ class EncoderDecoder : public Scorer { } virtual void SetSource(const Sentence& source) { - encoder_->GetContext(source.GetWords(sourceIndex_), + encoder_->GetContext(source.GetWords(tab_), SourceContext_); } @@ -107,7 +109,7 @@ class EncoderDecoderLoader : public Loader { public: EncoderDecoderLoader(const YAML::Node& config) : Loader(config) {} - + virtual void Load() { std::string path = Get("path"); auto devices = God::Get>("devices"); @@ -126,7 +128,7 @@ class EncoderDecoderLoader : public Loader { size_t d = weights_[i]->GetDevice(); cudaSetDevice(d); size_t tab = Has("tab") ? Get("tab") : 0; - return ScorerPtr(new EncoderDecoder(*weights_[i], tab)); + return ScorerPtr(new EncoderDecoder(*weights_[i], config_, tab)); } private: diff --git a/src/decoder/god.cu b/src/decoder/god.cu index eb10dce3..d2138459 100644 --- a/src/decoder/god.cu +++ b/src/decoder/god.cu @@ -7,6 +7,7 @@ #include "config.h" #include "scorer.h" #include "threadpool.h" +#include "file_stream.h" #include "loader_factory.h" God God::instance_; @@ -76,7 +77,7 @@ void God::CleanUp() { void God::LoadWeights(const std::string& path) { LOG(info) << "Reading weights from " << path; - std::ifstream fweights(path.c_str()); + InputFileStream fweights(path); std::string name; float weight; size_t i = 0; diff --git a/src/decoder/god.h b/src/decoder/god.h index 6110f2a7..70c15c9a 100644 --- a/src/decoder/god.h +++ b/src/decoder/god.h @@ -7,9 +7,6 @@ #include "loader.h" #include "scorer.h" #include "logging.h" - -// this should not be here -#include "kenlm.h" class Weights; diff --git a/src/decoder/scorer.h b/src/decoder/scorer.h index ef6671a2..ea139938 100644 --- a/src/decoder/scorer.h +++ b/src/decoder/scorer.h @@ -28,8 +28,8 @@ typedef std::vector Probs; class Scorer { public: - Scorer() : sourceIndex_(0) {} - Scorer(size_t sourceIndex) : sourceIndex_(sourceIndex) {} + Scorer(const YAML::Node& config, size_t tab) + : config_(config), tab_(tab) {} virtual ~Scorer() {} @@ -52,13 +52,14 @@ class Scorer { virtual void CleanUpAfterSentence() {} protected: - size_t sourceIndex_; + const YAML::Node& config_; + size_t tab_; }; class SourceIndependentScorer : public Scorer { public: - SourceIndependentScorer() : Scorer(0) {} - SourceIndependentScorer(size_t) : Scorer(0) {} + SourceIndependentScorer(const YAML::Node& config, size_t) + : Scorer(config, 0) {} virtual ~SourceIndependentScorer() {}