mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
better file handling, better exception handling
This commit is contained in:
parent
b80ec4fa04
commit
e6cb99c48c
@ -10,7 +10,7 @@ SET(CUDA_PROPAGATE_HOST_FLAGS OFF)
|
||||
include_directories(${amunn_SOURCE_DIR})
|
||||
find_package(CUDA REQUIRED)
|
||||
|
||||
find_package(Boost COMPONENTS system filesystem program_options timer)
|
||||
find_package(Boost COMPONENTS system filesystem program_options timer iostreams)
|
||||
if(Boost_FOUND)
|
||||
include_directories(${Boost_INCLUDE_DIRS})
|
||||
set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})
|
||||
@ -24,15 +24,15 @@ if (YAMLCPP_FOUND)
|
||||
set(EXT_LIBS ${EXT_LIBS} ${YAMLCPP_LIBRARY})
|
||||
endif (YAMLCPP_FOUND)
|
||||
|
||||
set(KENLM CACHE STRING "Path to compiled kenlm directory")
|
||||
if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a")
|
||||
message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a")
|
||||
endif()
|
||||
#set(KENLM CACHE STRING "Path to compiled kenlm directory")
|
||||
#if (NOT EXISTS "${KENLM}/build/lib/libkenlm.a")
|
||||
# message(FATAL_ERROR "Could not find ${KENLM}/build/lib/libkenlm.a")
|
||||
#endif()
|
||||
|
||||
set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm.a)
|
||||
set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm_util.a)
|
||||
include_directories(${KENLM})
|
||||
add_definitions(-DKENLM_MAX_ORDER=6)
|
||||
#set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm.a)
|
||||
#set(EXT_LIBS ${EXT_LIBS} ${KENLM}/build/lib/libkenlm_util.a)
|
||||
#include_directories(${KENLM})
|
||||
#add_definitions(-DKENLM_MAX_ORDER=6)
|
||||
|
||||
find_package (BZip2)
|
||||
if (BZIP2_FOUND)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import sys
|
||||
import math
|
||||
import yaml
|
||||
from collections import Counter
|
||||
|
||||
c = Counter()
|
||||
@ -10,7 +11,9 @@ for line in sys.stdin:
|
||||
c[word] += 1
|
||||
N += 1
|
||||
|
||||
keys = sorted([k for k in c])
|
||||
for word in keys:
|
||||
out = dict()
|
||||
for word in c:
|
||||
idf = math.log(float(N) / float(c[word])) / math.log(N)
|
||||
print word, ":", idf
|
||||
out[word] = idf
|
||||
|
||||
yaml.safe_dump(out, sys.stdout, default_flow_style=False, allow_unicode=True)
|
@ -17,7 +17,7 @@ add_library(librescorer OBJECT
|
||||
|
||||
add_library(libamunn OBJECT
|
||||
decoder/config.cpp
|
||||
decoder/kenlm.cpp
|
||||
# decoder/kenlm.cpp
|
||||
)
|
||||
|
||||
cuda_add_executable(
|
||||
|
@ -1,4 +1,4 @@
|
||||
#include "util/exception.hh"
|
||||
#include "exception.h"
|
||||
|
||||
#ifdef __GXX_RTTI
|
||||
#include <typeinfo>
|
||||
@ -17,14 +17,18 @@ namespace util {
|
||||
Exception::Exception() throw() {}
|
||||
Exception::~Exception() throw() {}
|
||||
|
||||
// Copy constructor: duplicates the message text accumulated in the other
// exception's stream into this exception's own stream.
Exception::Exception(const Exception& o) throw() {
  const std::string message(o.what_.str());
  what_.str(message);
}
|
||||
|
||||
void Exception::SetLocation(const char *file, unsigned int line, const char *func, const char *child_name, const char *condition) {
|
||||
/* The child class might have set some text, but we want this to come first.
|
||||
* Another option would be passing this information to the constructor, but
|
||||
* then child classes would have to accept constructor arguments and pass
|
||||
* them down.
|
||||
*/
|
||||
std::string old_text;
|
||||
what_.swap(old_text);
|
||||
std::string old_text = what_.str();
|
||||
what_.str(std::string());
|
||||
what_ << file << ':' << line;
|
||||
if (func) what_ << " in " << func << " threw ";
|
||||
if (child_name) {
|
||||
|
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "util/string_stream.hh"
|
||||
|
||||
#include <sstream>
|
||||
#include <exception>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
@ -15,6 +14,7 @@ class Exception : public std::exception {
|
||||
public:
|
||||
Exception() throw();
|
||||
virtual ~Exception() throw();
|
||||
Exception(const Exception& o) throw();
|
||||
|
||||
// Returns the message text accumulated in what_.
// NOTE(review): if what_ is the std::stringstream member, what_.str() yields a
// temporary std::string, so the pointer returned here refers to storage that
// is destroyed when the return statement finishes. Confirm, and consider
// caching the string in a member so the pointer stays valid.
const char *what() const throw() { return what_.str().c_str(); }
|
||||
|
||||
@ -34,7 +34,7 @@ class Exception : public std::exception {
|
||||
typedef T Identity;
|
||||
};
|
||||
|
||||
StringStream what_;
|
||||
std::stringstream what_;
|
||||
};
|
||||
|
||||
/* This implements the normal operator<< for Exception and all its children.
|
||||
|
42
src/common/file_stream.h
Normal file
42
src/common/file_stream.h
Normal file
@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
|
||||
#include <boost/filesystem.hpp>
|
||||
#include <boost/filesystem/fstream.hpp>
|
||||
#include <boost/iostreams/filtering_stream.hpp>
|
||||
#include <boost/iostreams/filter/gzip.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "exception.h"
|
||||
|
||||
class InputFileStream {
|
||||
public:
|
||||
InputFileStream(const std::string& file)
|
||||
: file_(file), ifstream_(file_)
|
||||
{
|
||||
UTIL_THROW_IF2(!boost::filesystem::exists(file_),
|
||||
"File " << file << " does not exist");
|
||||
|
||||
if(file_.extension() == ".gz")
|
||||
istream_.push(boost::iostreams::gzip_decompressor());
|
||||
istream_.push(ifstream_);
|
||||
}
|
||||
|
||||
operator std::istream& () {
|
||||
return istream_;
|
||||
}
|
||||
|
||||
operator bool () {
|
||||
return istream_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
friend InputFileStream& operator>>(InputFileStream& stream, T& t) {
|
||||
stream.istream_ >> t;
|
||||
return stream;
|
||||
}
|
||||
|
||||
private:
|
||||
boost::filesystem::path file_;
|
||||
boost::filesystem::ifstream ifstream_;
|
||||
boost::iostreams::filtering_istream istream_;
|
||||
};
|
@ -3,19 +3,19 @@
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "types.h"
|
||||
#include "utils.h"
|
||||
#include "file_stream.h"
|
||||
|
||||
class Vocab {
|
||||
public:
|
||||
Vocab(const std::string& txt) {
|
||||
std::ifstream in(txt.c_str());
|
||||
Vocab(const std::string& path) {
|
||||
InputFileStream in(path);
|
||||
size_t c = 0;
|
||||
std::string line;
|
||||
while(std::getline(in, line)) {
|
||||
while(std::getline((std::istream&)in, line)) {
|
||||
str2id_[line] = c++;
|
||||
id2str_.push_back(line);
|
||||
}
|
||||
|
@ -3,36 +3,48 @@
|
||||
#include <vector>
|
||||
|
||||
#include "types.h"
|
||||
#include "file_stream.h"
|
||||
#include "scorer.h"
|
||||
#include "matrix.h"
|
||||
|
||||
typedef std::vector<Word> SrcTrgMap;
|
||||
typedef std::vector<float> Penalties;
|
||||
|
||||
// Placeholder state type: the APE penalty scorer is stateless, so this class
// carries no data of its own.
class ApePenaltyState : public State {
};
|
||||
|
||||
class ApePenalty : public Scorer {
|
||||
private:
|
||||
const SrcTrgMap& srcTrgMap_;
|
||||
const Penalties& penalties_;
|
||||
|
||||
public:
|
||||
ApePenalty(size_t sourceIndex)
|
||||
: Scorer(sourceIndex)
|
||||
ApePenalty(
|
||||
const SrcTrgMap& srcTrgMap,
|
||||
const Penalties& penalties,
|
||||
const YAML::Node& config,
|
||||
size_t tab)
|
||||
: Scorer(config, tab), srcTrgMap_(srcTrgMap),
|
||||
penalties_(penalties)
|
||||
{ }
|
||||
|
||||
// @TODO: make this work on GPU
|
||||
virtual void SetSource(const Sentence& source) {
|
||||
const Words& words = source.GetWords(sourceIndex_);
|
||||
const Vocab& svcb = God::GetSourceVocab(sourceIndex_);
|
||||
const Vocab& tvcb = God::GetTargetVocab();
|
||||
const Words& words = source.GetWords(tab_);
|
||||
|
||||
costs_.clear();
|
||||
costs_.resize(tvcb.size(), -1.0);
|
||||
for(auto& s : words) {
|
||||
const std::string& sstr = svcb[s];
|
||||
Word t = tvcb[sstr];
|
||||
costs_.resize(penalties_.size());
|
||||
algo::copy(penalties_.begin(), penalties_.end(), costs_.begin());
|
||||
|
||||
for(auto&& s : words) {
|
||||
Word t = srcTrgMap_[s];
|
||||
if(t != UNK && t < costs_.size())
|
||||
costs_[t] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// @TODO: make this work on GPU
|
||||
virtual void Score(const State& in,
|
||||
Prob& prob,
|
||||
State& out) {
|
||||
@ -65,11 +77,34 @@ class ApePenaltyLoader : public Loader {
|
||||
: Loader(config) {}
|
||||
|
||||
virtual void Load() {
|
||||
// @TODO: IDF weights
|
||||
size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
|
||||
const Vocab& svcb = God::GetSourceVocab(tab);
|
||||
const Vocab& tvcb = God::GetTargetVocab();
|
||||
|
||||
srcTrgMap_.resize(svcb.size(), UNK);
|
||||
for(Word s = 0; s < svcb.size(); ++s)
|
||||
srcTrgMap_[s] = tvcb[svcb[s]];
|
||||
|
||||
penalties_.resize(tvcb.size(), -1.0);
|
||||
|
||||
if(Has("path")) {
|
||||
LOG(info) << "Loading APE penalties from " << Get<std::string>("path");
|
||||
YAML::Node penalties = YAML::Load(InputFileStream(Get<std::string>("path")));
|
||||
for(auto&& pair : penalties) {
|
||||
std::string entry = pair.first.as<std::string>();
|
||||
float penalty = pair.second.as<float>();
|
||||
penalties_[tvcb[entry]] = -penalty;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual ScorerPtr NewScorer(size_t taskId) {
|
||||
size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
|
||||
return ScorerPtr(new ApePenalty(tab));
|
||||
return ScorerPtr(new ApePenalty(srcTrgMap_, penalties_,
|
||||
config_, tab));
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
SrcTrgMap srcTrgMap_;
|
||||
Penalties penalties_;
|
||||
};
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "config.h"
|
||||
#include "file_stream.h"
|
||||
#include "exception.h"
|
||||
|
||||
#define SET_OPTION(key, type) \
|
||||
@ -79,7 +80,7 @@ void Config::AddOptions(size_t argc, char** argv) {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
config_ = YAML::LoadFile(configPath);
|
||||
config_ = YAML::Load(InputFileStream(configPath));
|
||||
|
||||
SET_OPTION("n-best", bool)
|
||||
SET_OPTION("normalize", bool)
|
||||
|
@ -38,8 +38,10 @@ class EncoderDecoder : public Scorer {
|
||||
typedef EncoderDecoderState EDState;
|
||||
|
||||
public:
|
||||
EncoderDecoder(const Weights& model, size_t tabIndex)
|
||||
: Scorer(tabIndex), model_(model),
|
||||
EncoderDecoder(const Weights& model,
|
||||
const YAML::Node& config,
|
||||
size_t tab)
|
||||
: Scorer(config, tab), model_(model),
|
||||
encoder_(new Encoder(model_)), decoder_(new Decoder(model_))
|
||||
{}
|
||||
|
||||
@ -65,7 +67,7 @@ class EncoderDecoder : public Scorer {
|
||||
}
|
||||
|
||||
virtual void SetSource(const Sentence& source) {
|
||||
encoder_->GetContext(source.GetWords(sourceIndex_),
|
||||
encoder_->GetContext(source.GetWords(tab_),
|
||||
SourceContext_);
|
||||
}
|
||||
|
||||
@ -107,7 +109,7 @@ class EncoderDecoderLoader : public Loader {
|
||||
public:
|
||||
EncoderDecoderLoader(const YAML::Node& config)
|
||||
: Loader(config) {}
|
||||
|
||||
|
||||
virtual void Load() {
|
||||
std::string path = Get<std::string>("path");
|
||||
auto devices = God::Get<std::vector<size_t>>("devices");
|
||||
@ -126,7 +128,7 @@ class EncoderDecoderLoader : public Loader {
|
||||
size_t d = weights_[i]->GetDevice();
|
||||
cudaSetDevice(d);
|
||||
size_t tab = Has("tab") ? Get<size_t>("tab") : 0;
|
||||
return ScorerPtr(new EncoderDecoder(*weights_[i], tab));
|
||||
return ScorerPtr(new EncoderDecoder(*weights_[i], config_, tab));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include "config.h"
|
||||
#include "scorer.h"
|
||||
#include "threadpool.h"
|
||||
#include "file_stream.h"
|
||||
#include "loader_factory.h"
|
||||
|
||||
God God::instance_;
|
||||
@ -76,7 +77,7 @@ void God::CleanUp() {
|
||||
|
||||
void God::LoadWeights(const std::string& path) {
|
||||
LOG(info) << "Reading weights from " << path;
|
||||
std::ifstream fweights(path.c_str());
|
||||
InputFileStream fweights(path);
|
||||
std::string name;
|
||||
float weight;
|
||||
size_t i = 0;
|
||||
|
@ -7,9 +7,6 @@
|
||||
#include "loader.h"
|
||||
#include "scorer.h"
|
||||
#include "logging.h"
|
||||
|
||||
// this should not be here
|
||||
#include "kenlm.h"
|
||||
|
||||
class Weights;
|
||||
|
||||
|
@ -28,8 +28,8 @@ typedef std::vector<Prob> Probs;
|
||||
|
||||
class Scorer {
|
||||
public:
|
||||
Scorer() : sourceIndex_(0) {}
|
||||
Scorer(size_t sourceIndex) : sourceIndex_(sourceIndex) {}
|
||||
// Builds a scorer from its configuration node and the index `tab`.
// config_ is declared as a const reference, so only a reference to `config`
// is kept here; the node passed in must outlive this scorer.
Scorer(const YAML::Node& config, size_t tab)
|
||||
: config_(config), tab_(tab) {}
|
||||
|
||||
virtual ~Scorer() {}
|
||||
|
||||
@ -52,13 +52,14 @@ class Scorer {
|
||||
virtual void CleanUpAfterSentence() {}
|
||||
|
||||
protected:
|
||||
size_t sourceIndex_;
|
||||
const YAML::Node& config_;
|
||||
size_t tab_;
|
||||
};
|
||||
|
||||
class SourceIndependentScorer : public Scorer {
|
||||
public:
|
||||
SourceIndependentScorer() : Scorer(0) {}
|
||||
SourceIndependentScorer(size_t) : Scorer(0) {}
|
||||
SourceIndependentScorer(const YAML::Node& config, size_t)
|
||||
: Scorer(config, 0) {}
|
||||
|
||||
virtual ~SourceIndependentScorer() {}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user