mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-30 23:42:30 +03:00
257 lines
6.8 KiB
C++
257 lines
6.8 KiB
C++
// $Id$
|
|
|
|
#include "ConfusionNet.h"
|
|
#include <sstream>
|
|
|
|
#include "FactorCollection.h"
|
|
#include "Util.h"
|
|
#include "PhraseDictionaryTreeAdaptor.h"
|
|
#include "TranslationOptionCollectionConfusionNet.h"
|
|
#include "StaticData.h"
|
|
#include "Sentence.h"
|
|
#include "UserMessage.h"
|
|
|
|
namespace Moses
|
|
{
|
|
struct CNStats {
|
|
size_t created,destr,read,colls,words;
|
|
|
|
CNStats() : created(0),destr(0),read(0),colls(0),words(0) {}
|
|
~CNStats() {
|
|
print(std::cerr);
|
|
}
|
|
|
|
void createOne() {
|
|
++created;
|
|
}
|
|
void destroyOne() {
|
|
++destr;
|
|
}
|
|
|
|
void collect(const ConfusionNet& cn) {
|
|
++read;
|
|
colls+=cn.GetSize();
|
|
for(size_t i=0; i<cn.GetSize(); ++i)
|
|
words+=cn[i].size();
|
|
}
|
|
void print(std::ostream& out) const {
|
|
if(created>0) {
|
|
out<<"confusion net statistics:\n"
|
|
" created:\t"<<created<<"\n"
|
|
" destroyed:\t"<<destr<<"\n"
|
|
" succ. read:\t"<<read<<"\n"
|
|
" columns:\t"<<colls<<"\n"
|
|
" words:\t"<<words<<"\n"
|
|
" avg. word/column:\t"<<words/(1.0*colls)<<"\n"
|
|
" avg. cols/sent:\t"<<colls/(1.0*read)<<"\n"
|
|
"\n\n";
|
|
}
|
|
}
|
|
|
|
};
|
|
|
|
CNStats stats;
|
|
|
|
size_t ConfusionNet::GetColumnIncrement(size_t i, size_t j) const
|
|
{
|
|
(void) i;
|
|
(void) j;
|
|
return 1;
|
|
}
|
|
|
|
// Default constructor: builds an empty net and records the creation in the
// file-level statistics.
ConfusionNet::ConfusionNet() : InputType()
{
  stats.createOne();
}
|
|
// Destructor: records the destruction in the file-level statistics.
ConfusionNet::~ConfusionNet()
{
  stats.destroyOne();
}
|
|
|
|
// Builds a confusion net from a plain sentence: one column per word, each
// column holding exactly that word with an empty score vector.
ConfusionNet::ConfusionNet(Sentence const& s)
  : InputType()
{
  // The destructor always reports to the statistics, so every constructor
  // has to report as well or the created/destroyed counters drift apart.
  stats.createOne();

  const size_t numWords = s.GetSize();
  data.resize(numWords);
  for(size_t i=0; i<numWords; ++i) {
    // Spell out the empty score vector instead of relying on an implicit
    // conversion from a floating-point literal to a vector.
    data[i].push_back(std::make_pair(s.GetWord(i),std::vector<float>()));
  }
}
|
|
|
|
bool ConfusionNet::ReadF(std::istream& in,
|
|
const std::vector<FactorType>& factorOrder,
|
|
int format)
|
|
{
|
|
VERBOSE(1, "read confusion net with format "<<format<<"\n");
|
|
switch(format) {
|
|
case 0:
|
|
return ReadFormat0(in,factorOrder);
|
|
case 1:
|
|
return ReadFormat1(in,factorOrder);
|
|
default:
|
|
std::stringstream strme;
|
|
strme << "ERROR: unknown format '"<<format
|
|
<<"' in ConfusionNet::Read";
|
|
UserMessage::Add(strme.str());
|
|
}
|
|
return false;
|
|
}
|
|
|
|
int ConfusionNet::Read(std::istream& in,
|
|
const std::vector<FactorType>& factorOrder)
|
|
{
|
|
int rv=ReadF(in,factorOrder,0);
|
|
if(rv) stats.collect(*this);
|
|
return rv;
|
|
}
|
|
|
|
|
|
void ConfusionNet::String2Word(const std::string& s,Word& w,
|
|
const std::vector<FactorType>& factorOrder)
|
|
{
|
|
std::vector<std::string> factorStrVector = Tokenize(s, "|");
|
|
for(size_t i=0; i<factorOrder.size(); ++i)
|
|
w.SetFactor(factorOrder[i],
|
|
FactorCollection::Instance().AddFactor(Input,factorOrder[i],
|
|
factorStrVector[i]));
|
|
}
|
|
|
|
bool ConfusionNet::ReadFormat0(std::istream& in,
|
|
const std::vector<FactorType>& factorOrder)
|
|
{
|
|
Clear();
|
|
std::string line;
|
|
size_t numLinkParams = StaticData::Instance().GetNumLinkParams();
|
|
size_t numLinkWeights = StaticData::Instance().GetNumInputScores();
|
|
bool addRealWordCount = ((numLinkParams + 1) == numLinkWeights);
|
|
|
|
while(getline(in,line)) {
|
|
std::istringstream is(line);
|
|
std::string word;
|
|
|
|
Column col;
|
|
while(is>>word) {
|
|
Word w;
|
|
String2Word(word,w,factorOrder);
|
|
std::vector<float> probs(numLinkWeights,0.0);
|
|
for(size_t i=0; i<numLinkParams; i++) {
|
|
double prob;
|
|
if (!(is>>prob)) {
|
|
TRACE_ERR("ERROR: unable to parse CN input - bad link probability, or wrong number of scores\n");
|
|
return false;
|
|
}
|
|
if(prob<0.0) {
|
|
VERBOSE(1, "WARN: negative prob: "<<prob<<" ->set to 0.0\n");
|
|
prob=0.0;
|
|
} else if (prob>1.0) {
|
|
VERBOSE(1, "WARN: prob > 1.0 : "<<prob<<" -> set to 1.0\n");
|
|
prob=1.0;
|
|
}
|
|
probs[i] = (std::max(static_cast<float>(log(prob)),LOWEST_SCORE));
|
|
|
|
}
|
|
//store 'real' word count in last feature if we have one more weight than we do arc scores and not epsilon
|
|
if (addRealWordCount && word!=EPSILON && word!="")
|
|
probs[numLinkParams] = -1.0;
|
|
col.push_back(std::make_pair(w,probs));
|
|
}
|
|
if(col.size()) {
|
|
data.push_back(col);
|
|
ShrinkToFit(data.back());
|
|
} else break;
|
|
}
|
|
return !data.empty();
|
|
}
|
|
bool ConfusionNet::ReadFormat1(std::istream& in,
|
|
const std::vector<FactorType>& factorOrder)
|
|
{
|
|
Clear();
|
|
std::string line;
|
|
if(!getline(in,line)) return 0;
|
|
size_t s;
|
|
if(getline(in,line)) s=atoi(line.c_str());
|
|
else return 0;
|
|
data.resize(s);
|
|
for(size_t i=0; i<data.size(); ++i) {
|
|
if(!getline(in,line)) return 0;
|
|
std::istringstream is(line);
|
|
if(!(is>>s)) return 0;
|
|
std::string word;
|
|
double prob;
|
|
data[i].resize(s);
|
|
for(size_t j=0; j<s; ++j)
|
|
if(is>>word>>prob) {
|
|
//TODO: we are only reading one prob from this input format, should read many... but this function is unused anyway. -JS
|
|
data[i][j].second = std::vector<float> (1);
|
|
data[i][j].second.push_back((float) log(prob));
|
|
if(data[i][j].second[0]<0) {
|
|
VERBOSE(1, "WARN: neg costs: "<<data[i][j].second[0]<<" -> set to 0\n");
|
|
data[i][j].second[0]=0.0;
|
|
}
|
|
String2Word(word,data[i][j].first,factorOrder);
|
|
} else return 0;
|
|
}
|
|
return !data.empty();
|
|
}
|
|
|
|
void ConfusionNet::Print(std::ostream& out) const
|
|
{
|
|
out<<"conf net: "<<data.size()<<"\n";
|
|
for(size_t i=0; i<data.size(); ++i) {
|
|
out<<i<<" -- ";
|
|
for(size_t j=0; j<data[i].size(); ++j) {
|
|
out<<"("<<data[i][j].first.ToString()<<", ";
|
|
for(std::vector<float>::const_iterator scoreIterator = data[i][j].second.begin(); scoreIterator<data[i][j].second.end(); scoreIterator++) {
|
|
out<<", "<<*scoreIterator;
|
|
}
|
|
out<<") ";
|
|
}
|
|
out<<"\n";
|
|
}
|
|
out<<"\n\n";
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
#pragma warning(disable:4716)
|
|
#endif
|
|
// Not supported for confusion nets: logs an error and aborts the process.
// The function never returns, so no Phrase is ever produced.
Phrase ConfusionNet::GetSubString(const WordsRange&) const
{
  TRACE_ERR("ERROR: call to ConfusionNet::GetSubString\n");
  abort();
}
|
|
|
|
std::string ConfusionNet::GetStringRep(const std::vector<FactorType> /* factorsToPrint */) const //not well defined yet
|
|
{
|
|
TRACE_ERR("ERROR: call to ConfusionNet::GeStringRep\n");
|
|
return "";
|
|
}
|
|
#ifdef _WIN32
|
|
#pragma warning(disable:4716)
|
|
#endif
|
|
// Not supported for confusion nets: logs an error and aborts the process.
// The function never returns, so no Word reference is ever produced.
const Word& ConfusionNet::GetWord(size_t) const
{
  TRACE_ERR("ERROR: call to ConfusionNet::GetWord\n");
  abort();
}
|
|
#ifdef _WIN32
|
|
#pragma warning(default:4716)
|
|
#endif
|
|
std::ostream& operator<<(std::ostream& out,const ConfusionNet& cn)
|
|
{
|
|
cn.Print(out);
|
|
return out;
|
|
}
|
|
|
|
// Allocates the translation option collection for this confusion net,
// configured with the option limit and threshold taken from StaticData.
// The object is created with new and returned as a raw pointer
// (presumably the caller takes ownership - verify against callers).
TranslationOptionCollection*
ConfusionNet::CreateTranslationOptionCollection(const TranslationSystem* system) const
{
  const size_t maxOptionsPerCoverage =
    StaticData::Instance().GetMaxNoTransOptPerCoverage();
  const float optionThreshold =
    StaticData::Instance().GetTranslationOptionThreshold();

  TranslationOptionCollection *rv =
    new TranslationOptionCollectionConfusionNet(system, *this,
        maxOptionsPerCoverage, optionThreshold);
  CHECK(rv);
  return rv;
}
|
|
|
|
}
|
|
|
|
|