Feature function interface for use in scoring

This commit is contained in:
Barry Haddow 2012-11-02 23:30:51 +00:00
parent 9a1ac30889
commit 62fa6d6f28
11 changed files with 690 additions and 191 deletions

View File

@ -114,7 +114,7 @@ project : requirements
;
#Add directories here if you want their incidental targets too (i.e. tests).
build-projects lm util search moses/src mert moses-cmd/src moses-chart-cmd/src mira scripts regression-testing ;
build-projects lm util phrase-extract search moses/src mert moses-cmd/src moses-chart-cmd/src mira scripts regression-testing ;
alias programs : lm//programs moses-chart-cmd/src//moses_chart moses-cmd/src//programs OnDiskPt//CreateOnDiskPt OnDiskPt//queryOnDiskPt mert//programs misc//programs symal phrase-extract phrase-extract//lexical-reordering phrase-extract//extract-ghkm phrase-extract//pcfg-extract phrase-extract//pcfg-score biconcor mira//mira contrib/server//mosesserver ;

View File

@ -1,33 +1,41 @@
obj InputFileStream.o : InputFileStream.cpp : <include>. ;
alias InputFileStream : InputFileStream.o ..//z ;
obj OutputFileStream.o : OutputFileStream.cpp : <include>. ;
alias OutputFileStream : OutputFileStream.o ..//z ;
obj tables-core.o : tables-core.cpp : <include>. ;
obj domain.o : domain.cpp : <include>. ;
obj domain.o : domain.cpp : <include>. <include>.. ;
obj AlignmentPhrase.o : AlignmentPhrase.cpp : <include>. ;
obj PhraseAlignment.o : PhraseAlignment.cpp : <include>. ;
obj ScoreFeature.o : ScoreFeature.cpp : <include>. <include>.. ;
obj SentenceAlignment.o : SentenceAlignment.cpp : <include>. ;
obj SyntaxTree.o : SyntaxTree.cpp : <include>. ;
obj XmlTree.o : XmlTree.cpp : <include>. ;
alias filestreams : InputFileStream.cpp OutputFileStream.cpp : : : <include>. ;
alias filestreams : InputFileStream OutputFileStream : : : <include>. ;
alias trees : SyntaxTree.cpp tables-core.o XmlTree.o : : : <include>. ;
exe extract : tables-core.o SentenceAlignment.o extract.cpp OutputFileStream.cpp InputFileStream ../moses/src//ThreadPool ..//boost_iostreams ;
exe extract : tables-core.o SentenceAlignment.o extract.cpp filestreams ../moses/src//ThreadPool ..//boost_iostreams ;
exe extract-rules : tables-core.o SentenceAlignment.o SyntaxTree.o XmlTree.o SentenceAlignmentWithSyntax.cpp HoleCollection.cpp extract-rules.cpp ExtractedRule.cpp OutputFileStream.cpp InputFileStream ..//boost_iostreams ;
exe extract-rules : tables-core.o SentenceAlignment.o SyntaxTree.o XmlTree.o SentenceAlignmentWithSyntax.cpp HoleCollection.cpp extract-rules.cpp ExtractedRule.cpp filestreams ..//boost_iostreams ;
exe extract-lex : extract-lex.cpp InputFileStream ;
exe extract-lex : extract-lex.cpp filestreams ;
exe score : tables-core.o domain.o AlignmentPhrase.o score.cpp PhraseAlignment.cpp OutputFileStream.cpp InputFileStream ..//boost_iostreams ;
exe score : tables-core.o domain.o AlignmentPhrase.o score.cpp ScoreFeature.o PhraseAlignment.o filestreams ../util//kenutil ..//boost_iostreams ;
exe consolidate : consolidate.cpp tables-core.o OutputFileStream.cpp InputFileStream ..//boost_iostreams ;
exe consolidate : consolidate.cpp tables-core.o filestreams ..//boost_iostreams ;
exe consolidate-direct : consolidate-direct.cpp OutputFileStream.cpp InputFileStream ..//boost_iostreams ;
exe consolidate-direct : consolidate-direct.cpp filestreams ..//boost_iostreams ;
exe consolidate-reverse : consolidate-reverse.cpp tables-core.o InputFileStream ;
exe consolidate-reverse : consolidate-reverse.cpp tables-core.o filestreams ;
exe relax-parse : tables-core.o SyntaxTree.o XmlTree.o relax-parse.cpp InputFileStream ;
exe relax-parse : tables-core.o SyntaxTree.o XmlTree.o relax-parse.cpp filestreams ;
exe statistics : tables-core.o AlignmentPhrase.o statistics.cpp InputFileStream ;
exe statistics : tables-core.o AlignmentPhrase.o statistics.cpp filestreams ;
alias programs : extract extract-rules extract-lex score consolidate consolidate-direct consolidate-reverse relax-parse statistics ;
import testing ;
run ScoreFeatureTest.cpp tables-core.o domain.o ScoreFeature.o PhraseAlignment.o filestreams ../util//kenutil ..//boost_unit_test_framework ..//boost_iostreams : : test.domain ;

View File

@ -59,5 +59,54 @@ public:
};
class PhraseAlignment;
typedef std::vector<PhraseAlignment*> PhraseAlignmentCollection;
//typedef std::vector<PhraseAlignmentCollection> PhrasePairGroup;
class PhraseAlignmentCollectionOrderer
{
public:
bool operator()(const PhraseAlignmentCollection &collA, const PhraseAlignmentCollection &collB) const
{
assert(collA.size() > 0);
assert(collB.size() > 0);
const PhraseAlignment &objA = *collA[0];
const PhraseAlignment &objB = *collB[0];
bool ret = objA < objB;
return ret;
}
};
//typedef std::set<PhraseAlignmentCollection, PhraseAlignmentCollectionOrderer> PhrasePairGroup;
class PhrasePairGroup
{
private:
typedef std::set<PhraseAlignmentCollection, PhraseAlignmentCollectionOrderer> Coll;
Coll m_coll;
public:
typedef Coll::iterator iterator;
typedef Coll::const_iterator const_iterator;
typedef std::vector<const PhraseAlignmentCollection *> SortedColl;
std::pair<Coll::iterator,bool> insert ( const PhraseAlignmentCollection& obj );
const SortedColl &GetSortedColl() const
{ return m_sortedColl; }
size_t GetSize() const
{ return m_coll.size(); }
private:
SortedColl m_sortedColl;
};
}

View File

@ -0,0 +1,98 @@
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2012- University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include "ScoreFeature.h"
#include "domain.h"
using namespace std;
namespace MosesTraining
{
const string& ScoreFeatureManager::usage() const
{
const static string& usage = "[--[Sparse]Domain[Indicator|Ratio|Subset|Bin] domain-file [bins]]" ;
return usage;
}
void ScoreFeatureManager::configure(const std::vector<std::string> args)
{
bool domainAdded = false;
bool sparseDomainAdded = false;
for (size_t i = 0; i < args.size(); ++i) {
if (args[i].substr(0,8) == "--Domain") {
string type = args[i].substr(8);
++i;
UTIL_THROW_IF(i == args.size(), ScoreFeatureArgumentException, "Missing domain file");
string domainFile = args[i];
UTIL_THROW_IF(domainAdded, ScoreFeatureArgumentException,
"Only allowed one domain feature");
if (type == "Subset") {
m_features.push_back(ScoreFeaturePtr(new SubsetDomainFeature(domainFile)));
} else if (type == "Ratio") {
m_features.push_back(ScoreFeaturePtr(new RatioDomainFeature(domainFile)));
} else if (type == "Indicator") {
m_features.push_back(ScoreFeaturePtr(new IndicatorDomainFeature(domainFile)));
} else {
UTIL_THROW(ScoreFeatureArgumentException, "Unknown domain feature type " << type);
}
domainAdded = true;
m_includeSentenceId = true;
} else if (args[i].substr(0,14) == "--SparseDomain") {
string type = args[i].substr(14);
++i;
UTIL_THROW_IF(i == args.size(), ScoreFeatureArgumentException, "Missing domain file");
string domainFile = args[i];
UTIL_THROW_IF(sparseDomainAdded, ScoreFeatureArgumentException,
"Only allowed one sparse domain feature");
if (type == "Subset") {
m_features.push_back(ScoreFeaturePtr(new SparseSubsetDomainFeature(domainFile)));
} else if (type == "Ratio") {
m_features.push_back(ScoreFeaturePtr(new SparseRatioDomainFeature(domainFile)));
} else if (type == "Indicator") {
m_features.push_back(ScoreFeaturePtr(new SparseIndicatorDomainFeature(domainFile)));
} else {
UTIL_THROW(ScoreFeatureArgumentException, "Unknown domain feature type " << type);
}
sparseDomainAdded = true;
m_includeSentenceId = true;
}
}
}
bool ScoreFeatureManager::equals(const PhraseAlignment& lhs, const PhraseAlignment& rhs) const
{
for (size_t i = 0; i < m_features.size(); ++i) {
if (!m_features[i]->equals(lhs,rhs)) return false;
}
return true;
}
void ScoreFeatureManager::addFeatures(const ScoreFeatureContext& context,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const
{
for (size_t i = 0; i < m_features.size(); ++i) {
m_features[i]->add(context, denseValues, sparseValues);
}
}
}

View File

@ -0,0 +1,136 @@
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2012- University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
/**
* This contains extra features that can be added to the scorer. To add a new feature:
* 1. Implement a subclass of ScoreFeature
* 2. Updated ScoreFeatureManager.configure() to configure your feature, and usage() to
* display usage info.
* 3. Write unit tests (see ScoreFeatureTest.cpp) and regression tests
**/
#pragma once
#include <string>
#include <map>
#include <vector>
#include <boost/shared_ptr.hpp>
#include "util/exception.hh"
#include "PhraseAlignment.h"
namespace MosesTraining
{
struct MaybeLog{
MaybeLog(bool useLog, float negativeLog):
m_useLog(useLog), m_negativeLog(negativeLog) {}
inline float operator() (float a) const
{ return m_useLog ? m_negativeLog*log(a) : a; }
float m_useLog;
float m_negativeLog;
};
class ScoreFeatureArgumentException : public util::Exception
{
public:
ScoreFeatureArgumentException() throw() {*this << "Unable to configure features: ";}
~ScoreFeatureArgumentException() throw() {}
};
/** Passed to each feature to be used to calculate its values */
struct ScoreFeatureContext
{
ScoreFeatureContext(
const PhraseAlignmentCollection &thePhrasePair,
float theCount, /* Total counts of all phrase pairs*/
const MaybeLog& theMaybeLog
) :
phrasePair(thePhrasePair),
count(theCount),
maybeLog(theMaybeLog)
{}
const PhraseAlignmentCollection& phrasePair;
float count;
MaybeLog maybeLog;
};
/**
* Abstract base class for extra features that can be added to the phrase table
* during scoring.
**/
class ScoreFeature
{
public:
/** Add the values for this feature function. */
virtual void add(const ScoreFeatureContext& context,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const = 0;
/** Return true if the two phrase pairs are equal from the point of this feature. Assume
that they already compare true according to PhraseAlignment.equals()
**/
virtual bool equals(const PhraseAlignment& lhs, const PhraseAlignment& rhs) const = 0;
virtual ~ScoreFeature() {}
};
typedef boost::shared_ptr<ScoreFeature> ScoreFeaturePtr;
class ScoreFeatureManager
{
public:
ScoreFeatureManager():
m_includeSentenceId(false) {}
/** To be appended to the score usage message */
const std::string& usage() const;
/** Pass the unused command-line arguments to configure the extra features */
void configure(const std::vector<std::string> args);
/** Add all the features */
void addFeatures(const ScoreFeatureContext& context,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const;
/**
* Used to tell if the PhraseAlignment should be considered the same by all
* extended features.
**/
bool equals(const PhraseAlignment& lhs, const PhraseAlignment& rhs) const;
const std::vector<ScoreFeaturePtr>& getFeatures() const {return m_features;}
/** Do we need to include sentence ids in phrase pairs? */
bool includeSentenceId() const {return m_includeSentenceId;}
private:
std::vector<ScoreFeaturePtr> m_features;
bool m_includeSentenceId;
};
}

View File

@ -0,0 +1,106 @@
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2012- University of Edinburgh
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
***********************************************************************/
#include "domain.h"
#include "ScoreFeature.h"
#include "tables-core.h"
#define BOOST_TEST_MODULE MosesTrainingScoreFeature
#include <boost/test/test_tools.hpp>
#include <boost/test/unit_test.hpp>
#include <boost/assign/list_of.hpp>
using namespace MosesTraining;
using namespace std;
//pesky global variables
namespace MosesTraining {
bool hierarchicalFlag = false;
Vocabulary vcbT;
Vocabulary vcbS;
}
const char *DomainFileLocation() {
if (boost::unit_test::framework::master_test_suite().argc < 2) {
return "test.domain";
}
return boost::unit_test::framework::master_test_suite().argv[1];
}
BOOST_AUTO_TEST_CASE(manager_configure_domain_except)
{
//Check that configure rejects illegal domain arg combinations
ScoreFeatureManager manager;
vector<string> args = boost::assign::list_of("--DomainRatio")("/dev/null")("--DomainIndicator")("/dev/null");
BOOST_CHECK_THROW(manager.configure(args), ScoreFeatureArgumentException);
args = boost::assign::list_of("--SparseDomainSubset")("/dev/null")("--SparseDomainRatio")("/dev/null");
BOOST_CHECK_THROW(manager.configure(args), ScoreFeatureArgumentException);
args = boost::assign::list_of("--SparseDomainBlah")("/dev/null");
BOOST_CHECK_THROW(manager.configure(args), ScoreFeatureArgumentException);
args = boost::assign::list_of("--DomainSubset");
BOOST_CHECK_THROW(manager.configure(args), ScoreFeatureArgumentException);
}
template <class Expected>
static void checkDomainConfigured(
const vector<string>& args)
{
ScoreFeatureManager manager;
manager.configure(args);
const std::vector<ScoreFeaturePtr>& features = manager.getFeatures();
BOOST_REQUIRE_EQUAL(features.size(), 1);
Expected* feature = dynamic_cast<Expected*>(features[0].get());
BOOST_REQUIRE(feature);
BOOST_CHECK(manager.includeSentenceId());
}
BOOST_AUTO_TEST_CASE(manager_config_domain)
{
checkDomainConfigured<RatioDomainFeature>
(boost::assign::list_of ("--DomainRatio")("/dev/null"));
checkDomainConfigured<IndicatorDomainFeature>
(boost::assign::list_of("--DomainIndicator")("/dev/null"));
checkDomainConfigured<SubsetDomainFeature>
(boost::assign::list_of("--DomainSubset")("/dev/null"));
checkDomainConfigured<SparseRatioDomainFeature>
(boost::assign::list_of("--SparseDomainRatio")("/dev/null"));
checkDomainConfigured<SparseIndicatorDomainFeature>
(boost::assign::list_of("--SparseDomainIndicator")("/dev/null"));
checkDomainConfigured<SparseSubsetDomainFeature>
(boost::assign::list_of("--SparseDomainSubset")("/dev/null"));
}
BOOST_AUTO_TEST_CASE(domain_equals)
{
SubsetDomainFeature feature(DomainFileLocation());
PhraseAlignment a1,a2,a3;
char buf1[] = "a ||| b ||| 0-0 ||| 1";
char buf2[] = "a ||| b ||| 0-0 ||| 2";
char buf3[] = "a ||| b ||| 0-0 ||| 3";
a1.create(buf1, 0, true); //domain a
a2.create(buf2, 1, true); //domain c
a3.create(buf3, 2, true); //domain c
BOOST_CHECK(feature.equals(a2,a3));
BOOST_CHECK(!feature.equals(a1,a3));
BOOST_CHECK(!feature.equals(a1,a3));
}

View File

@ -39,7 +39,7 @@ void Domain::load( const std::string &domainFileName ) {
}
// get domain name based on sentence number
string Domain::getDomainOfSentence( int sentenceId ) {
string Domain::getDomainOfSentence( int sentenceId ) const {
for(size_t i=0; i<spec.size(); i++) {
if (sentenceId <= spec[i].first) {
return spec[i].second;
@ -48,5 +48,128 @@ string Domain::getDomainOfSentence( int sentenceId ) {
return "undefined";
}
DomainFeature::DomainFeature(const string& domainFile)
{
//process domain file
m_domain.load(domainFile);
}
void DomainFeature::add(const ScoreFeatureContext& context,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const
{
map< string, float > domainCount;
for(size_t i=0; i<context.phrasePair.size(); i++) {
string d = m_domain.getDomainOfSentence(context.phrasePair[i]->sentenceId );
if (domainCount.find( d ) == domainCount.end()) {
domainCount[d] = context.phrasePair[i]->count;
} else {
domainCount[d] += context.phrasePair[i]->count;
}
}
add(domainCount, context.count, context.maybeLog, denseValues, sparseValues);
}
void SubsetDomainFeature::add(const map<string,float>& domainCount,float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const
{
if (m_domain.list.size() > 6) {
UTIL_THROW_IF(m_domain.list.size() > 6, ScoreFeatureArgumentException,
"too many domains for core domain subset features");
}
size_t bitmap = 0;
for(size_t bit = 0; bit < m_domain.list.size(); bit++) {
if (domainCount.find( m_domain.list[ bit ] ) != domainCount.end()) {
bitmap += 1 << bit;
}
}
for(size_t i = 1; i < (1 << m_domain.list.size()); i++) {
denseValues.push_back(maybeLog( (bitmap == i) ? 2.718 : 1 ));
}
}
void SparseSubsetDomainFeature::add(const map<string,float>& domainCount,float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const
{
typedef vector<string>::const_iterator I;
ostringstream key;
key << "doms";
for (I i = m_domain.list.begin(); i != m_domain.list.end(); ++i) {
if (domainCount.find(*i) != domainCount.end()) {
key << "_" << *i;
}
}
sparseValues[key.str()] = 1;
}
void RatioDomainFeature::add(const map<string,float>& domainCount,float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const
{
typedef vector< string >::const_iterator I;
for (I i = m_domain.list.begin(); i != m_domain.list.end(); i++ ) {
map<string,float>::const_iterator dci = domainCount.find(*i);
if (dci == domainCount.end() ) {
denseValues.push_back(maybeLog( 1 ));
} else {
denseValues.push_back(maybeLog(exp( dci->second / count ) ));
}
}
}
void SparseRatioDomainFeature::add(const map<string,float>& domainCount,float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const
{
typedef map< string, float >::const_iterator I;
for (I i=domainCount.begin(); i != domainCount.end(); i++) {
sparseValues["domr_" + i->first] = (i->second / count);
}
}
void IndicatorDomainFeature::add(const map<string,float>& domainCount,float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const
{
typedef vector< string >::const_iterator I;
for (I i = m_domain.list.begin(); i != m_domain.list.end(); i++ ) {
map<string,float>::const_iterator dci = domainCount.find(*i);
if (dci == domainCount.end() ) {
denseValues.push_back(maybeLog( 1 ));
} else {
denseValues.push_back(maybeLog(2.718));
}
}
}
void SparseIndicatorDomainFeature::add(const map<string,float>& domainCount,float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const
{
typedef map< string, float >::const_iterator I;
for (I i=domainCount.begin(); i != domainCount.end(); i++) {
sparseValues["dom_" + i->first] = 1;
}
}
bool DomainFeature::equals(const PhraseAlignment& lhs, const PhraseAlignment& rhs) const
{
return m_domain.getDomainOfSentence(lhs.sentenceId) ==
m_domain.getDomainOfSentence( rhs.sentenceId);
}
}

View File

@ -12,6 +12,8 @@
#include <map>
#include <cmath>
#include "ScoreFeature.h"
extern std::vector<std::string> tokenize( const char*);
namespace MosesTraining
@ -24,9 +26,114 @@ public:
std::vector< std::string > list;
std::map< std::string, int > name2id;
void load( const std::string &fileName );
std::string getDomainOfSentence( int sentenceId );
std::string getDomainOfSentence( int sentenceId ) const;
};
class DomainFeature : public ScoreFeature
{
public:
DomainFeature(const std::string& domainFile);
bool equals(const PhraseAlignment& lhs, const PhraseAlignment& rhs) const;
void add(const ScoreFeatureContext& context,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const;
protected:
/** Overriden in subclass */
virtual void add(const std::map<std::string,float>& domainCounts, float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const = 0;
Domain m_domain;
};
class SubsetDomainFeature : public DomainFeature
{
public:
SubsetDomainFeature(const std::string& domainFile) :
DomainFeature(domainFile) {}
protected:
virtual void add(const std::map<std::string,float>& domainCounts, float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const;
};
class SparseSubsetDomainFeature : public DomainFeature
{
public:
SparseSubsetDomainFeature(const std::string& domainFile) :
DomainFeature(domainFile) {}
protected:
virtual void add(const std::map<std::string,float>& domainCounts, float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const;
};
class IndicatorDomainFeature : public DomainFeature
{
public:
IndicatorDomainFeature(const std::string& domainFile) :
DomainFeature(domainFile) {}
protected:
virtual void add(const std::map<std::string,float>& domainCounts, float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const;
};
class SparseIndicatorDomainFeature : public DomainFeature
{
public:
SparseIndicatorDomainFeature(const std::string& domainFile) :
DomainFeature(domainFile) {}
protected:
virtual void add(const std::map<std::string,float>& domainCounts, float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const;
};
class RatioDomainFeature : public DomainFeature
{
public:
RatioDomainFeature(const std::string& domainFile) :
DomainFeature(domainFile) {}
protected:
virtual void add(const std::map<std::string,float>& domainCounts, float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const;
};
class SparseRatioDomainFeature : public DomainFeature
{
public:
SparseRatioDomainFeature(const std::string& domainFile) :
DomainFeature(domainFile) {}
protected:
virtual void add(const std::map<std::string,float>& domainCounts, float count,
const MaybeLog& maybeLog,
std::vector<float>& denseValues,
std::map<std::string,float>& sparseValues) const;
};
}
#endif

View File

@ -29,6 +29,7 @@
#include <algorithm>
#include "SafeGetline.h"
#include "ScoreFeature.h"
#include "tables-core.h"
#include "domain.h"
#include "PhraseAlignment.h"
@ -52,10 +53,9 @@ bool conditionOnTargetLhsFlag = false;
bool wordAlignmentFlag = false;
bool goodTuringFlag = false;
bool kneserNeyFlag = false;
#define COC_MAX 10
bool logProbFlag = false;
int negLogProb = 1;
inline float maybeLogProb( float a ) { return logProbFlag ? negLogProb*log(a) : a; }
#define COC_MAX 10
bool lexFlag = true;
bool unalignedFlag = false;
bool unalignedFWFlag = false;
@ -65,12 +65,6 @@ bool crossedNonTerm = false;
int countOfCounts[COC_MAX+1];
int totalDistinct = 0;
float minCountHierarchical = 0;
bool domainFlag = false;
bool domainRatioFlag = false;
bool domainSubsetFlag = false;
bool domainSparseFlag = false;
Domain *domain;
bool includeSentenceIdFlag = false;
Vocabulary vcbT;
Vocabulary vcbS;
@ -80,9 +74,9 @@ Vocabulary vcbS;
vector<string> tokenize( const char [] );
void writeCountOfCounts( const string &fileNameCountOfCounts );
void processPhrasePairs( vector< PhraseAlignment > & , ostream &phraseTableFile, bool isSingleton);
void processPhrasePairs( vector< PhraseAlignment > & , ostream &phraseTableFile, bool isSingleton, const ScoreFeatureManager& featureManager, const MaybeLog& maybeLog);
const PhraseAlignment &findBestAlignment(const PhraseAlignmentCollection &phrasePair );
void outputPhrasePair(const PhraseAlignmentCollection &phrasePair, float, int, ostream &phraseTableFile, bool isSingleton );
void outputPhrasePair(const PhraseAlignmentCollection &phrasePair, float, int, ostream &phraseTableFile, bool isSingleton, const ScoreFeatureManager& featureManager, const MaybeLog& maybeLog );
double computeLexicalTranslation( const PHRASE &, const PHRASE &, const PhraseAlignment & );
double computeUnalignedPenalty( const PHRASE &, const PHRASE &, const PhraseAlignment & );
set<string> functionWordList;
@ -99,8 +93,10 @@ int main(int argc, char* argv[])
cerr << "Score v2.0 written by Philipp Koehn\n"
<< "scoring methods for extracted rules\n";
ScoreFeatureManager featureManager;
if (argc < 4) {
cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring] [--KneserNey] [--WordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count] [--OutputNTLengths] [--PCFG] [--UnpairedExtractFormat] [--ConditionOnTargetLHS] [--[Sparse]Domain[Indicator|Ratio|Subset|Bin] domain-file [bins]] [--Singleton] [--CrossedNonTerm] \n";
cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring] [--KneserNey] [--WordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count] [--OutputNTLengths] [--PCFG] [--UnpairedExtractFormat] [--ConditionOnTargetLHS] [--Singleton] [--CrossedNonTerm] \n";
cerr << featureManager.usage() << endl;
exit(1);
}
string fileNameExtract = argv[1];
@ -109,6 +105,7 @@ int main(int argc, char* argv[])
string fileNameCountOfCounts;
char* fileNameFunctionWords = NULL;
char* fileNameDomain = NULL;
vector<string> featureArgs; //all unknown args passed to feature manager
for(int i=4; i<argc; i++) {
if (strcmp(argv[i],"inverse") == 0 || strcmp(argv[i],"--Inverse") == 0) {
@ -151,23 +148,7 @@ int main(int argc, char* argv[])
}
fileNameFunctionWords = argv[++i];
cerr << "using unaligned function word penalty with function words from " << fileNameFunctionWords << endl;
} else if (strcmp(argv[i],"--SparseDomainIndicator") == 0 ||
strcmp(argv[i],"--SparseDomainRatio") == 0 ||
strcmp(argv[i],"--SparseDomainSubset") == 0 ||
strcmp(argv[i],"--DomainIndicator") == 0 ||
strcmp(argv[i],"--DomainRatio") == 0 ||
strcmp(argv[i],"--DomainSubset") == 0) {
includeSentenceIdFlag = true;
domainFlag = true;
domainSparseFlag = strstr( argv[i], "Sparse" );
domainRatioFlag = strstr( argv[i], "Ratio" );
domainSubsetFlag = strstr( argv[i], "Subset" );
if (i+1==argc) {
cerr << "ERROR: specify domain info file with " << argv[i] << endl;
exit(1);
}
fileNameDomain = argv[++i];
} else if (strcmp(argv[i],"--LogProb") == 0) {
} else if (strcmp(argv[i],"--LogProb") == 0) {
logProbFlag = true;
cerr << "using log-probabilities\n";
} else if (strcmp(argv[i],"--NegLogProb") == 0) {
@ -187,11 +168,18 @@ int main(int argc, char* argv[])
crossedNonTerm = true;
cerr << "crossed non-term reordering feature\n";
} else {
cerr << "ERROR: unknown option " << argv[i] << endl;
exit(1);
featureArgs.push_back(argv[i]);
for (; i < argc && strncmp(argv[i], "--", 2); ++i) {
featureArgs.push_back(argv[i]);
}
}
}
MaybeLog maybeLogProb(logProbFlag, negLogProb);
//configure extra features
if (!inverseFlag) featureManager.configure(featureArgs);
// lexical translation table
if (lexFlag)
lexTable.load( fileNameLex );
@ -200,18 +188,6 @@ int main(int argc, char* argv[])
if (unalignedFWFlag)
loadFunctionWords( fileNameFunctionWords );
// load domain information
if (domainFlag) {
if (inverseFlag) {
domainFlag = false;
includeSentenceIdFlag = false;
}
else {
domain = new Domain;
domain->load( fileNameDomain );
}
}
// compute count of counts for Good Turing discounting
if (goodTuringFlag || kneserNeyFlag) {
for(int i=1; i<=COC_MAX; i++) countOfCounts[i] = 0;
@ -268,16 +244,14 @@ int main(int argc, char* argv[])
// create new phrase pair
PhraseAlignment phrasePair;
phrasePair.create( line, i, includeSentenceIdFlag );
phrasePair.create( line, i, featureManager.includeSentenceId());
lastCount = phrasePair.count;
lastPcfgSum = phrasePair.pcfgSum;
// only differs in count? just add count
if (lastPhrasePair != NULL
&& lastPhrasePair->equals( phrasePair )
&& (!domainFlag
|| domain->getDomainOfSentence( lastPhrasePair->sentenceId )
== domain->getDomainOfSentence( phrasePair.sentenceId ) )) {
&& lastPhrasePair->equals( phrasePair )
&& featureManager.equals(*lastPhrasePair, phrasePair)) {
lastPhrasePair->count += phrasePair.count;
lastPhrasePair->pcfgSum += phrasePair.pcfgSum;
continue;
@ -286,7 +260,7 @@ int main(int argc, char* argv[])
// if new source phrase, process last batch
if (lastPhrasePair != NULL &&
lastPhrasePair->GetSource() != phrasePair.GetSource()) {
processPhrasePairs( phrasePairsWithSameF, *phraseTableFile, isSingleton );
processPhrasePairs( phrasePairsWithSameF, *phraseTableFile, isSingleton, featureManager, maybeLogProb );
phrasePairsWithSameF.clear();
isSingleton = false;
@ -301,7 +275,7 @@ int main(int argc, char* argv[])
phrasePairsWithSameF.push_back( phrasePair );
lastPhrasePair = &phrasePairsWithSameF.back();
}
processPhrasePairs( phrasePairsWithSameF, *phraseTableFile, isSingleton );
processPhrasePairs( phrasePairsWithSameF, *phraseTableFile, isSingleton, featureManager, maybeLogProb );
phraseTableFile->flush();
if (phraseTableFile != &cout) {
@ -335,7 +309,7 @@ void writeCountOfCounts( const string &fileNameCountOfCounts )
countOfCountsFile.Close();
}
void processPhrasePairs( vector< PhraseAlignment > &phrasePair, ostream &phraseTableFile, bool isSingleton )
void processPhrasePairs( vector< PhraseAlignment > &phrasePair, ostream &phraseTableFile, bool isSingleton, const ScoreFeatureManager& featureManager, const MaybeLog& maybeLogProb )
{
if (phrasePair.size() == 0) return;
@ -376,7 +350,7 @@ void processPhrasePairs( vector< PhraseAlignment > &phrasePair, ostream &phraseT
for(iter = sortedColl.begin(); iter != sortedColl.end(); ++iter)
{
const PhraseAlignmentCollection &group = **iter;
outputPhrasePair( group, totalSource, phrasePairGroup.GetSize(), phraseTableFile, isSingleton );
outputPhrasePair( group, totalSource, phrasePairGroup.GetSize(), phraseTableFile, isSingleton, featureManager, maybeLogProb );
}
}
@ -493,9 +467,9 @@ void outputNTLengthProbs(ostream &phraseTableFile, const map<size_t, map<size_t,
}
bool calcCrossedNonTerm(int sourcePos, int targetPos, const std::vector< std::set<size_t> > &alignedToS)
bool calcCrossedNonTerm(size_t sourcePos, size_t targetPos, const std::vector< std::set<size_t> > &alignedToS)
{
for (int currSource = 0; currSource < alignedToS.size(); ++currSource)
for (size_t currSource = 0; currSource < alignedToS.size(); ++currSource)
{
if (currSource == sourcePos)
{ // skip
@ -526,7 +500,7 @@ int calcCrossedNonTerm(const PHRASE &phraseS, const PhraseAlignment &bestAlignme
{
const std::vector< std::set<size_t> > &alignedToS = bestAlignment.alignedToS;
for (int sourcePos = 0; sourcePos < alignedToS.size(); ++sourcePos)
for (size_t sourcePos = 0; sourcePos < alignedToS.size(); ++sourcePos)
{
const std::set<size_t> &targetSet = alignedToS[sourcePos];
@ -537,7 +511,7 @@ int calcCrossedNonTerm(const PHRASE &phraseS, const PhraseAlignment &bestAlignme
if (isNonTerm)
{
assert(targetSet.size() == 1);
int targetPos = *targetSet.begin();
size_t targetPos = *targetSet.begin();
bool ret = calcCrossedNonTerm(sourcePos, targetPos, alignedToS);
if (ret)
return 1;
@ -547,7 +521,8 @@ int calcCrossedNonTerm(const PHRASE &phraseS, const PhraseAlignment &bestAlignme
return 0;
}
void outputPhrasePair(const PhraseAlignmentCollection &phrasePair, float totalCount, int distinctCount, ostream &phraseTableFile, bool isSingleton )
void outputPhrasePair(const PhraseAlignmentCollection &phrasePair, float totalCount, int distinctCount, ostream &phraseTableFile, bool isSingleton, const ScoreFeatureManager& featureManager,
const MaybeLog& maybeLogProb )
{
if (phrasePair.size() == 0) return;
@ -559,17 +534,7 @@ void outputPhrasePair(const PhraseAlignmentCollection &phrasePair, float totalCo
count += phrasePair[i]->count;
}
// compute domain counts
map< string, float > domainCount;
if (domainFlag) {
for(size_t i=0; i<phrasePair.size(); i++) {
string d = domain->getDomainOfSentence( phrasePair[i]->sentenceId );
if (domainCount.find( d ) == domainCount.end())
domainCount[ d ] = phrasePair[i]->count;
else
domainCount[ d ] += phrasePair[i]->count;
}
}
// collect count of count statistics
if (goodTuringFlag || kneserNeyFlag) {
@ -620,19 +585,19 @@ void outputPhrasePair(const PhraseAlignmentCollection &phrasePair, float totalCo
// lexical translation probability
if (lexFlag) {
double lexScore = computeLexicalTranslation( phraseS, phraseT, bestAlignment);
phraseTableFile << maybeLogProb( lexScore );
phraseTableFile << maybeLogProb(lexScore );
}
// unaligned word penalty
if (unalignedFlag) {
double penalty = computeUnalignedPenalty( phraseS, phraseT, bestAlignment);
phraseTableFile << " " << maybeLogProb( penalty );
phraseTableFile << " " << maybeLogProb(penalty );
}
// unaligned function word penalty
if (unalignedFWFlag) {
double penalty = computeUnalignedFWPenalty( phraseS, phraseT, bestAlignment);
phraseTableFile << " " << maybeLogProb( penalty );
phraseTableFile << " " << maybeLogProb(penalty );
}
if (singletonFeature) {
@ -645,67 +610,21 @@ void outputPhrasePair(const PhraseAlignmentCollection &phrasePair, float totalCo
// target-side PCFG score
if (pcfgFlag && !inverseFlag) {
phraseTableFile << " " << maybeLogProb( pcfgScore );
phraseTableFile << " " << maybeLogProb(pcfgScore );
}
// domain count features
if (domainFlag) {
if (domainSparseFlag) {
// sparse, subset
if (domainSubsetFlag) {
typedef vector< string >::const_iterator I;
phraseTableFile << " doms";
for (I i = domain->list.begin(); i != domain->list.end(); i++ ) {
if (domainCount.find( *i ) != domainCount.end() ) {
phraseTableFile << "_" << *i;
}
}
phraseTableFile << " 1";
}
// sparse, indicator or ratio
else {
typedef map< string, float >::const_iterator I;
for (I i=domainCount.begin(); i != domainCount.end(); i++) {
if (domainRatioFlag) {
phraseTableFile << " domr_" << i->first << " " << (i->second / count);
}
else {
phraseTableFile << " dom_" << i->first << " 1";
}
}
}
}
// core, subset
else if (domainSubsetFlag) {
if (domain->list.size() > 6) {
cerr << "ERROR: too many domains for core domain subset features\n";
exit(1);
}
size_t bitmap = 0;
for(size_t bit = 0; bit < domain->list.size(); bit++) {
if (domainCount.find( domain->list[ bit ] ) != domainCount.end()) {
bitmap += 1 << bit;
}
}
for(size_t i = 1; i < (1 << domain->list.size()); i++) {
phraseTableFile << " " << maybeLogProb( (bitmap == i) ? 2.718 : 1 );
}
}
// core, indicator or ratio
else {
typedef vector< string >::const_iterator I;
for (I i = domain->list.begin(); i != domain->list.end(); i++ ) {
if (domainCount.find( *i ) == domainCount.end() ) {
phraseTableFile << " " << maybeLogProb( 1 );
}
else if (domainRatioFlag) {
phraseTableFile << " " << maybeLogProb( exp( domainCount[ *i ] / count ) );
}
else {
phraseTableFile << " " << maybeLogProb( 2.718 );
}
}
}
// extra features
ScoreFeatureContext context(phrasePair, count, maybeLogProb);
vector<float> extraDense;
map<string,float> extraSparse;
featureManager.addFeatures(context, extraDense, extraSparse);
for (size_t i = 0; i < extraDense.size(); ++i) {
phraseTableFile << " " << extraDense[i];
}
for (map<string,float>::const_iterator i = extraSparse.begin();
i != extraSparse.end(); ++i) {
phraseTableFile << " " << i->first << " " << i->second;
}
phraseTableFile << " ||| ";

View File

@ -12,55 +12,6 @@
namespace MosesTraining
{
class PhraseAlignment;
typedef std::vector<PhraseAlignment*> PhraseAlignmentCollection;
//typedef std::vector<PhraseAlignmentCollection> PhrasePairGroup;
class PhraseAlignmentCollectionOrderer
{
public:
bool operator()(const PhraseAlignmentCollection &collA, const PhraseAlignmentCollection &collB) const
{
assert(collA.size() > 0);
assert(collB.size() > 0);
const PhraseAlignment &objA = *collA[0];
const PhraseAlignment &objB = *collB[0];
bool ret = objA < objB;
return ret;
}
};
//typedef std::set<PhraseAlignmentCollection, PhraseAlignmentCollectionOrderer> PhrasePairGroup;
class PhrasePairGroup
{
private:
typedef std::set<PhraseAlignmentCollection, PhraseAlignmentCollectionOrderer> Coll;
Coll m_coll;
public:
typedef Coll::iterator iterator;
typedef Coll::const_iterator const_iterator;
typedef std::vector<const PhraseAlignmentCollection *> SortedColl;
std::pair<Coll::iterator,bool> insert ( const PhraseAlignmentCollection& obj );
const SortedColl &GetSortedColl() const
{ return m_sortedColl; }
size_t GetSize() const
{ return m_coll.size(); }
private:
SortedColl m_sortedColl;
};
class LexicalTable
{
public:

View File

@ -0,0 +1,2 @@
1 a
3 c