More reorganisation of options.

This commit is contained in:
Ulrich Germann 2015-08-06 22:52:34 +01:00
parent 524109e2ca
commit 6c1d9e2431
16 changed files with 290 additions and 150 deletions

View File

@ -159,13 +159,15 @@ int main(int argc, char* argv[])
} }
StaticData& SD = const_cast<StaticData&>(StaticData::Instance()); StaticData& SD = const_cast<StaticData&>(StaticData::Instance());
SD.SetUseLatticeMBR(true); LMBR_Options& lmbr = SD.options().lmbr;
MBR_Options& mbr = SD.options().mbr;
lmbr.enabled = true;
boost::shared_ptr<IOWrapper> ioWrapper(new IOWrapper); boost::shared_ptr<IOWrapper> ioWrapper(new IOWrapper);
if (!ioWrapper) { if (!ioWrapper) {
throw runtime_error("Failed to initialise IOWrapper"); throw runtime_error("Failed to initialise IOWrapper");
} }
size_t nBestSize = SD.GetMBRSize(); size_t nBestSize = mbr.size;
if (nBestSize <= 0) { if (nBestSize <= 0) {
throw new runtime_error("Non-positive size specified for n-best list"); throw new runtime_error("Non-positive size specified for n-best list");
@ -187,13 +189,13 @@ int main(int argc, char* argv[])
manager.CalcNBest(nBestSize, nBestList,true); manager.CalcNBest(nBestSize, nBestList,true);
//grid search //grid search
BOOST_FOREACH(float const& p, pgrid) { BOOST_FOREACH(float const& p, pgrid) {
SD.SetLatticeMBRPrecision(p); lmbr.precision = p;
BOOST_FOREACH(float const& r, rgrid) { BOOST_FOREACH(float const& r, rgrid) {
SD.SetLatticeMBRPRatio(r); lmbr.ratio = r;
BOOST_FOREACH(size_t const prune_i, prune_grid) { BOOST_FOREACH(size_t const prune_i, prune_grid) {
SD.SetLatticeMBRPruningFactor(size_t(prune_i)); lmbr.pruning_factor = prune_i;
BOOST_FOREACH(float const& scale_i, scale_grid) { BOOST_FOREACH(float const& scale_i, scale_grid) {
SD.SetMBRScale(scale_i); mbr.scale = scale_i;
size_t lineCount = source->GetTranslationId(); size_t lineCount = source->GetTranslationId();
cout << lineCount << " ||| " << p << " " cout << lineCount << " ||| " << p << " "
<< r << " " << size_t(prune_i) << " " << scale_i << r << " " << size_t(prune_i) << " " << scale_i

View File

@ -289,7 +289,7 @@ void ChartHypothesis::CleanupArcList()
const StaticData &staticData = StaticData::Instance(); const StaticData &staticData = StaticData::Instance();
size_t nBestSize = staticData.options().nbest.nbest_size; size_t nBestSize = staticData.options().nbest.nbest_size;
bool distinctNBest = (staticData.options().nbest.only_distinct bool distinctNBest = (staticData.options().nbest.only_distinct
|| staticData.UseMBR() || staticData.options().mbr.enabled
|| staticData.GetOutputSearchGraph() || staticData.GetOutputSearchGraph()
|| staticData.GetOutputSearchGraphHypergraph()); || staticData.GetOutputSearchGraphHypergraph());

View File

@ -363,13 +363,13 @@ CleanupArcList()
*/ */
const StaticData &staticData = StaticData::Instance(); const StaticData &staticData = StaticData::Instance();
size_t nBestSize = staticData.options().nbest.nbest_size; size_t nBestSize = staticData.options().nbest.nbest_size;
bool distinctNBest = (staticData.options().nbest.only_distinct || bool distinctNBest = (m_manager.options().nbest.only_distinct ||
staticData.GetLatticeSamplesSize() || staticData.GetLatticeSamplesSize() ||
staticData.UseMBR() || m_manager.options().mbr.enabled ||
staticData.GetOutputSearchGraph() || staticData.GetOutputSearchGraph() ||
staticData.GetOutputSearchGraphSLF() || staticData.GetOutputSearchGraphSLF() ||
staticData.GetOutputSearchGraphHypergraph() || staticData.GetOutputSearchGraphHypergraph() ||
staticData.UseLatticeMBR()); m_manager.options().lmbr.enabled);
if (!distinctNBest && m_arcList->size() > nBestSize * 5) { if (!distinctNBest && m_arcList->size() > nBestSize * 5) {
// prune arc list only if there too many arcs // prune arc list only if there too many arcs

View File

@ -490,13 +490,18 @@ bool Edge::operator< (const Edge& compare ) const
ostream& operator<< (ostream& out, const Edge& edge) ostream& operator<< (ostream& out, const Edge& edge)
{ {
out << "Head: " << edge.m_headNode->GetId() << ", Tail: " << edge.m_tailNode->GetId() << ", Score: " << edge.m_score << ", Phrase: " << edge.m_targetPhrase << endl; out << "Head: " << edge.m_headNode->GetId()
<< ", Tail: " << edge.m_tailNode->GetId()
<< ", Score: " << edge.m_score
<< ", Phrase: " << edge.m_targetPhrase << endl;
return out; return out;
} }
bool ascendingCoverageCmp(const Hypothesis* a, const Hypothesis* b) bool ascendingCoverageCmp(const Hypothesis* a, const Hypothesis* b)
{ {
return a->GetWordsBitmap().GetNumWordsCovered() < b->GetWordsBitmap().GetNumWordsCovered(); return (a->GetWordsBitmap().GetNumWordsCovered()
<
b->GetWordsBitmap().GetNumWordsCovered());
} }
void getLatticeMBRNBest(const Manager& manager, const TrellisPathList& nBestList, void getLatticeMBRNBest(const Manager& manager, const TrellisPathList& nBestList,
@ -509,15 +514,20 @@ void getLatticeMBRNBest(const Manager& manager, const TrellisPathList& nBestList
std::map < const Hypothesis*, set <const Hypothesis*> > outgoingHyps; std::map < const Hypothesis*, set <const Hypothesis*> > outgoingHyps;
map<const Hypothesis*, vector<Edge> > incomingEdges; map<const Hypothesis*, vector<Edge> > incomingEdges;
vector< float> estimatedScores; vector< float> estimatedScores;
manager.GetForwardBackwardSearchGraph(&connected, &connectedList, &outgoingHyps, &estimatedScores); manager.GetForwardBackwardSearchGraph(&connected, &connectedList,
pruneLatticeFB(connectedList, outgoingHyps, incomingEdges, estimatedScores, manager.GetBestHypothesis(), staticData.GetLatticeMBRPruningFactor(),staticData.GetMBRScale()); &outgoingHyps, &estimatedScores);
LMBR_Options const& lmbr = manager.options().lmbr;
MBR_Options const& mbr = manager.options().mbr;
pruneLatticeFB(connectedList, outgoingHyps, incomingEdges, estimatedScores,
manager.GetBestHypothesis(), lmbr.pruning_factor, mbr.scale);
calcNgramExpectations(connectedList, incomingEdges, ngramPosteriors,true); calcNgramExpectations(connectedList, incomingEdges, ngramPosteriors,true);
vector<float> mbrThetas = staticData.GetLatticeMBRThetas(); vector<float> mbrThetas = lmbr.theta;
float p = staticData.GetLatticeMBRPrecision(); float p = lmbr.precision;
float r = staticData.GetLatticeMBRPRatio(); float r = lmbr.ratio;
float mapWeight = staticData.GetLatticeMBRMapWeight(); float mapWeight = lmbr.map_weight;
if (mbrThetas.size() == 0) { //thetas not specified on the command line, use p and r instead if (mbrThetas.size() == 0) {
// thetas were not specified on the command line, so use p and r instead
mbrThetas.push_back(-1); //Theta 0 mbrThetas.push_back(-1); //Theta 0
mbrThetas.push_back(1/(bleu_order*p)); mbrThetas.push_back(1/(bleu_order*p));
for (size_t i = 2; i <= bleu_order; ++i) { for (size_t i = 2; i <= bleu_order; ++i) {
@ -537,7 +547,7 @@ void getLatticeMBRNBest(const Manager& manager, const TrellisPathList& nBestList
for (iter = nBestList.begin() ; iter != nBestList.end() ; ++iter, ++ctr) { for (iter = nBestList.begin() ; iter != nBestList.end() ; ++iter, ++ctr) {
const TrellisPath &path = **iter; const TrellisPath &path = **iter;
solutions.push_back(LatticeMBRSolution(path,iter==nBestList.begin())); solutions.push_back(LatticeMBRSolution(path,iter==nBestList.begin()));
solutions.back().CalcScore(ngramPosteriors,mbrThetas,mapWeight); solutions.back().CalcScore(ngramPosteriors, mbrThetas, mapWeight);
sort(solutions.begin(), solutions.end(), comparator); sort(solutions.begin(), solutions.end(), comparator);
while (solutions.size() > n) { while (solutions.size() > n) {
solutions.pop_back(); solutions.pop_back();
@ -568,7 +578,10 @@ const TrellisPath doConsensusDecoding(const Manager& manager, const TrellisPathL
map<const Hypothesis*, vector<Edge> > incomingEdges; map<const Hypothesis*, vector<Edge> > incomingEdges;
vector< float> estimatedScores; vector< float> estimatedScores;
manager.GetForwardBackwardSearchGraph(&connected, &connectedList, &outgoingHyps, &estimatedScores); manager.GetForwardBackwardSearchGraph(&connected, &connectedList, &outgoingHyps, &estimatedScores);
pruneLatticeFB(connectedList, outgoingHyps, incomingEdges, estimatedScores, manager.GetBestHypothesis(), staticData.GetLatticeMBRPruningFactor(),staticData.GetMBRScale()); LMBR_Options const& lmbr = manager.options().lmbr;
MBR_Options const& mbr = manager.options().mbr;
pruneLatticeFB(connectedList, outgoingHyps, incomingEdges, estimatedScores,
manager.GetBestHypothesis(), lmbr.pruning_factor, mbr.scale);
calcNgramExpectations(connectedList, incomingEdges, ngramExpectations,false); calcNgramExpectations(connectedList, incomingEdges, ngramExpectations,false);
//expected length is sum of expected unigram counts //expected length is sum of expected unigram counts

View File

@ -1492,7 +1492,7 @@ void Manager::OutputBest(OutputCollector *collector) const
// MAP decoding: best hypothesis // MAP decoding: best hypothesis
const Hypothesis* bestHypo = NULL; const Hypothesis* bestHypo = NULL;
if (!staticData.UseMBR()) { if (!options().mbr.enabled) {
bestHypo = GetBestHypothesis(); bestHypo = GetBestHypothesis();
if (bestHypo) { if (bestHypo) {
if (StaticData::Instance().GetOutputHypoScore()) { if (StaticData::Instance().GetOutputHypoScore()) {
@ -1534,7 +1534,7 @@ void Manager::OutputBest(OutputCollector *collector) const
// MBR decoding (n-best MBR, lattice MBR, consensus) // MBR decoding (n-best MBR, lattice MBR, consensus)
else { else {
// we first need the n-best translations // we first need the n-best translations
size_t nBestSize = staticData.GetMBRSize(); size_t nBestSize = options().mbr.size;
if (nBestSize <= 0) { if (nBestSize <= 0) {
cerr << "ERROR: negative size for number of MBR candidate translations not allowed (option mbr-size)" << endl; cerr << "ERROR: negative size for number of MBR candidate translations not allowed (option mbr-size)" << endl;
exit(1); exit(1);
@ -1547,11 +1547,11 @@ void Manager::OutputBest(OutputCollector *collector) const
} }
// lattice MBR // lattice MBR
if (staticData.UseLatticeMBR()) { if (options().lmbr.enabled) {
if (staticData.options().nbest.enabled) { if (staticData.options().nbest.enabled) {
//lattice mbr nbest //lattice mbr nbest
vector<LatticeMBRSolution> solutions; vector<LatticeMBRSolution> solutions;
size_t n = min(nBestSize, staticData.options().nbest.nbest_size); size_t n = min(nBestSize, options().nbest.nbest_size);
getLatticeMBRNBest(*this,nBestList,solutions,n); getLatticeMBRNBest(*this,nBestList,solutions,n);
OutputLatticeMBRNBest(m_latticeNBestOut, solutions, translationId); OutputLatticeMBRNBest(m_latticeNBestOut, solutions, translationId);
} else { } else {
@ -1566,7 +1566,7 @@ void Manager::OutputBest(OutputCollector *collector) const
} }
// consensus decoding // consensus decoding
else if (staticData.UseConsensusDecoding()) { else if (options().search.consensus) {
const TrellisPath &conBestHypo = doConsensusDecoding(*this,nBestList); const TrellisPath &conBestHypo = doConsensusDecoding(*this,nBestList);
OutputBestHypo(conBestHypo, translationId, OutputBestHypo(conBestHypo, translationId,
staticData.GetReportSegmentation(), staticData.GetReportSegmentation(),
@ -1608,15 +1608,15 @@ void Manager::OutputNBest(OutputCollector *collector) const
const StaticData &staticData = StaticData::Instance(); const StaticData &staticData = StaticData::Instance();
long translationId = m_source.GetTranslationId(); long translationId = m_source.GetTranslationId();
if (staticData.UseLatticeMBR()) { if (options().lmbr.enabled) {
if (staticData.options().nbest.enabled) { if (staticData.options().nbest.enabled) {
collector->Write(translationId, m_latticeNBestOut.str()); collector->Write(translationId, m_latticeNBestOut.str());
} }
} else { } else {
TrellisPathList nBestList; TrellisPathList nBestList;
ostringstream out; ostringstream out;
CalcNBest(staticData.options().nbest.nbest_size, nBestList, CalcNBest(options().nbest.nbest_size, nBestList,
staticData.options().nbest.only_distinct); options().nbest.only_distinct);
OutputNBest(out, nBestList, staticData.GetOutputFactorOrder(), OutputNBest(out, nBestList, staticData.GetOutputFactorOrder(),
m_source.GetTranslationId(), m_source.GetTranslationId(),
staticData.GetReportSegmentation()); staticData.GetReportSegmentation());

View File

@ -463,64 +463,64 @@ StaticData
} }
void // void
StaticData // StaticData
::ini_mbr_options() // ::ini_mbr_options()
{ // {
// minimum Bayes risk decoding // // minimum Bayes risk decoding
m_parameter->SetParameter(m_mbr, "minimum-bayes-risk", false ); // m_parameter->SetParameter(m_mbr, "minimum-bayes-risk", false );
m_parameter->SetParameter<size_t>(m_mbrSize, "mbr-size", 200); // m_parameter->SetParameter<size_t>(m_mbrSize, "mbr-size", 200);
m_parameter->SetParameter(m_mbrScale, "mbr-scale", 1.0f); // m_parameter->SetParameter(m_mbrScale, "mbr-scale", 1.0f);
} // }
void // void
StaticData // StaticData
::ini_lmbr_options() // ::ini_lmbr_options()
{ // {
const PARAM_VEC *params; // const PARAM_VEC *params;
//lattice mbr // //lattice mbr
m_parameter->SetParameter(m_useLatticeMBR, "lminimum-bayes-risk", false ); // // m_parameter->SetParameter(m_useLatticeMBR, "lminimum-bayes-risk", false );
if (m_useLatticeMBR && m_mbr) { // // if (m_useLatticeMBR && m_mbr) {
cerr << "Error: Cannot use both n-best mbr and lattice mbr together" << endl; // // cerr << "Error: Cannot use both n-best mbr and lattice mbr together" << endl;
exit(1); // // exit(1);
} // // }
// lattice MBR // // // lattice MBR
if (m_useLatticeMBR) m_mbr = true; // // if (m_useLatticeMBR) m_mbr = true;
m_parameter->SetParameter<size_t>(m_lmbrPruning, "lmbr-pruning-factor", 30); // m_parameter->SetParameter<size_t>(m_lmbrPruning, "lmbr-pruning-factor", 30);
m_parameter->SetParameter(m_lmbrPrecision, "lmbr-p", 0.8f); // m_parameter->SetParameter(m_lmbrPrecision, "lmbr-p", 0.8f);
m_parameter->SetParameter(m_lmbrPRatio, "lmbr-r", 0.6f); // m_parameter->SetParameter(m_lmbrPRatio, "lmbr-r", 0.6f);
m_parameter->SetParameter(m_lmbrMapWeight, "lmbr-map-weight", 0.0f); // m_parameter->SetParameter(m_lmbrMapWeight, "lmbr-map-weight", 0.0f);
m_parameter->SetParameter(m_useLatticeHypSetForLatticeMBR, "lattice-hypo-set", false ); // m_parameter->SetParameter(m_useLatticeHypSetForLatticeMBR, "lattice-hypo-set", false );
params = m_parameter->GetParam("lmbr-thetas"); // params = m_parameter->GetParam("lmbr-thetas");
if (params) { // if (params) {
m_lmbrThetas = Scan<float>(*params); // m_lmbrThetas = Scan<float>(*params);
} // }
} // }
void // void
StaticData // StaticData
::ini_consensus_decoding_options() // ::ini_consensus_decoding_options()
{ // {
//consensus decoding // //consensus decoding
m_parameter->SetParameter(m_useConsensusDecoding, "consensus-decoding", false ); // m_parameter->SetParameter(m_useConsensusDecoding, "consensus-decoding", false );
if (m_useConsensusDecoding && m_mbr) { // if (m_useConsensusDecoding && m_mbr) {
cerr<< "Error: Cannot use consensus decoding together with mbr" << endl; // cerr<< "Error: Cannot use consensus decoding together with mbr" << endl;
exit(1); // exit(1);
} // }
if (m_useConsensusDecoding) m_mbr=true; // if (m_useConsensusDecoding) m_mbr=true;
} // }
void // void
StaticData // StaticData
::ini_mira_options() // ::ini_mira_options()
{ // {
//mira training // //mira training
m_parameter->SetParameter(m_mira, "mira", false ); // m_parameter->SetParameter(m_mira, "mira", false );
} // }
bool StaticData::LoadData(Parameter *parameter) bool StaticData::LoadData(Parameter *parameter)
{ {
@ -559,15 +559,19 @@ bool StaticData::LoadData(Parameter *parameter)
// ini_cube_pruning_options(); // ini_cube_pruning_options();
ini_oov_options(); ini_oov_options();
ini_mbr_options(); // ini_mbr_options();
ini_lmbr_options(); // ini_lmbr_options();
ini_consensus_decoding_options(); // ini_consensus_decoding_options();
ini_mira_options(); // ini_mira_options();
// set m_nbest_options.enabled = true if necessary: // set m_nbest_options.enabled = true if necessary:
if (m_mbr || m_useLatticeMBR || m_outputSearchGraph || m_outputSearchGraphSLF if (m_options.mbr.enabled
|| m_mira || m_outputSearchGraphHypergraph || m_useConsensusDecoding || m_options.mira
|| m_options.search.consensus
|| m_outputSearchGraph
|| m_outputSearchGraphSLF
|| m_outputSearchGraphHypergraph
#ifdef HAVE_PROTOBUF #ifdef HAVE_PROTOBUF
|| m_outputSearchGraphPB || m_outputSearchGraphPB
#endif #endif

View File

@ -145,18 +145,18 @@ protected:
XmlInputType m_xmlInputType; //! method for handling sentence XML input XmlInputType m_xmlInputType; //! method for handling sentence XML input
std::pair<std::string,std::string> m_xmlBrackets; //! strings to use as XML tags' opening and closing brackets. Default are "<" and ">" std::pair<std::string,std::string> m_xmlBrackets; //! strings to use as XML tags' opening and closing brackets. Default are "<" and ">"
bool m_mbr; //! use MBR decoder // bool m_mbr; //! use MBR decoder
bool m_useLatticeMBR; //! use MBR decoder // bool m_useLatticeMBR; //! use MBR decoder
bool m_mira; // do mira training // bool m_mira; // do mira training
bool m_useConsensusDecoding; //! Use Consensus decoding (DeNero et al 2009) // bool m_useConsensusDecoding; //! Use Consensus decoding (DeNero et al 2009)
size_t m_mbrSize; //! number of translation candidates considered // size_t m_mbrSize; //! number of translation candidates considered
float m_mbrScale; //! scaling factor for computing marginal probability of candidate translation // float m_mbrScale; //! scaling factor for computing marginal probability of candidate translation
size_t m_lmbrPruning; //! average number of nodes per word wanted in pruned lattice // size_t m_lmbrPruning; //! average number of nodes per word wanted in pruned lattice
std::vector<float> m_lmbrThetas; //! theta(s) for lattice mbr calculation // std::vector<float> m_lmbrThetas; //! theta(s) for lattice mbr calculation
bool m_useLatticeHypSetForLatticeMBR; //! to use nbest as hypothesis set during lattice MBR // bool m_useLatticeHypSetForLatticeMBR; //! to use nbest as hypothesis set during lattice MBR
float m_lmbrPrecision; //! unigram precision theta - see Tromble et al 08 for more details // float m_lmbrPrecision; //! unigram precision theta - see Tromble et al 08 for more details
float m_lmbrPRatio; //! decaying factor for ngram thetas - see Tromble et al 08 for more details // float m_lmbrPRatio; //! decaying factor for ngram thetas - see Tromble et al 08 for more details
float m_lmbrMapWeight; //! Weight given to the map solution. See Kumar et al 09 for details // float m_lmbrMapWeight; //! Weight given to the map solution. See Kumar et al 09 for details
size_t m_lmcache_cleanup_threshold; //! number of translations after which LM claenup is performed (0=never, N=after N translations; default is 1) size_t m_lmcache_cleanup_threshold; //! number of translations after which LM claenup is performed (0=never, N=after N translations; default is 1)
bool m_lmEnableOOVFeature; bool m_lmEnableOOVFeature;
@ -512,55 +512,55 @@ public:
const std::string& GetFactorDelimiter() const { const std::string& GetFactorDelimiter() const {
return m_factorDelimiter; return m_factorDelimiter;
} }
bool UseMBR() const { // bool UseMBR() const {
return m_mbr; // return m_mbr;
} // }
bool UseLatticeMBR() const { // bool UseLatticeMBR() const {
return m_useLatticeMBR ; // return m_useLatticeMBR ;
} // }
bool UseConsensusDecoding() const { // bool UseConsensusDecoding() const {
return m_useConsensusDecoding; // return m_useConsensusDecoding;
} // }
void SetUseLatticeMBR(bool flag) { // void SetUseLatticeMBR(bool flag) {
m_useLatticeMBR = flag; // m_useLatticeMBR = flag;
} // }
size_t GetMBRSize() const { // size_t GetMBRSize() const {
return m_mbrSize; // return m_mbrSize;
} // }
float GetMBRScale() const { // float GetMBRScale() const {
return m_mbrScale; // return m_mbrScale;
} // }
void SetMBRScale(float scale) { // void SetMBRScale(float scale) {
m_mbrScale = scale; // m_mbrScale = scale;
} // }
size_t GetLatticeMBRPruningFactor() const { // size_t GetLatticeMBRPruningFactor() const {
return m_lmbrPruning; // return m_lmbrPruning;
} // }
void SetLatticeMBRPruningFactor(size_t prune) { // void SetLatticeMBRPruningFactor(size_t prune) {
m_lmbrPruning = prune; // m_lmbrPruning = prune;
} // }
const std::vector<float>& GetLatticeMBRThetas() const { // const std::vector<float>& GetLatticeMBRThetas() const {
return m_lmbrThetas; // return m_lmbrThetas;
} // }
bool UseLatticeHypSetForLatticeMBR() const { // bool UseLatticeHypSetForLatticeMBR() const {
return m_useLatticeHypSetForLatticeMBR; // return m_useLatticeHypSetForLatticeMBR;
} // }
float GetLatticeMBRPrecision() const { // float GetLatticeMBRPrecision() const {
return m_lmbrPrecision; // return m_lmbrPrecision;
} // }
void SetLatticeMBRPrecision(float p) { // void SetLatticeMBRPrecision(float p) {
m_lmbrPrecision = p; // m_lmbrPrecision = p;
} // }
float GetLatticeMBRPRatio() const { // float GetLatticeMBRPRatio() const {
return m_lmbrPRatio; // return m_lmbrPRatio;
} // }
void SetLatticeMBRPRatio(float r) { // void SetLatticeMBRPRatio(float r) {
m_lmbrPRatio = r; // m_lmbrPRatio = r;
} // }
float GetLatticeMBRMapWeight() const { // float GetLatticeMBRMapWeight() const {
return m_lmbrMapWeight; // return m_lmbrMapWeight;
} // }
// bool UseTimeout() const { // bool UseTimeout() const {
// return m_timeout; // return m_timeout;

View File

@ -92,7 +92,7 @@ float calculate_score(const vector< vector<const Factor*> > & sents, int ref, in
const TrellisPath doMBR(const TrellisPathList& nBestList) const TrellisPath doMBR(const TrellisPathList& nBestList)
{ {
float marginal = 0; float marginal = 0;
float mbr_scale = StaticData::Instance().options().mbr.scale;
vector<float> joint_prob_vec; vector<float> joint_prob_vec;
vector< vector<const Factor*> > translations; vector< vector<const Factor*> > translations;
float joint_prob; float joint_prob;
@ -104,14 +104,13 @@ const TrellisPath doMBR(const TrellisPathList& nBestList)
float maxScore = -1e20; float maxScore = -1e20;
for (iter = nBestList.begin() ; iter != nBestList.end() ; ++iter) { for (iter = nBestList.begin() ; iter != nBestList.end() ; ++iter) {
const TrellisPath &path = **iter; const TrellisPath &path = **iter;
float score = StaticData::Instance().GetMBRScale() float score = mbr_scale * path.GetScoreBreakdown()->GetWeightedScore();
* path.GetScoreBreakdown()->GetWeightedScore();
if (maxScore < score) maxScore = score; if (maxScore < score) maxScore = score;
} }
for (iter = nBestList.begin() ; iter != nBestList.end() ; ++iter) { for (iter = nBestList.begin() ; iter != nBestList.end() ; ++iter) {
const TrellisPath &path = **iter; const TrellisPath &path = **iter;
joint_prob = UntransformScore(StaticData::Instance().GetMBRScale() * path.GetScoreBreakdown()->GetWeightedScore() - maxScore); joint_prob = UntransformScore(mbr_scale * path.GetScoreBreakdown()->GetWeightedScore() - maxScore);
marginal += joint_prob; marginal += joint_prob;
joint_prob_vec.push_back(joint_prob); joint_prob_vec.push_back(joint_prob);

View File

@ -19,6 +19,11 @@ namespace Moses
if (!reordering.init(param)) return false; if (!reordering.init(param)) return false;
if (!context.init(param)) return false; if (!context.init(param)) return false;
if (!input.init(param)) return false; if (!input.init(param)) return false;
if (!mbr.init(param)) return false;
if (!lmbr.init(param)) return false;
param.SetParameter(mira, "mira", false);
return sanity_check(); return sanity_check();
} }
@ -26,6 +31,26 @@ namespace Moses
AllOptions:: AllOptions::
sanity_check() sanity_check()
{ {
using namespace std;
if (lmbr.enabled)
{
if (mbr.enabled)
{
cerr << "Error: Cannot use both n-best mbr and lattice mbr together" << endl;
return false;
}
mbr.enabled = true;
}
if (search.consensus)
{
if (mbr.enabled)
{
cerr << "Error: Cannot use consensus decoding together with mbr" << endl;
return false;
}
mbr.enabled = true;
}
return true; return true;
} }
} }

View File

@ -8,6 +8,8 @@
#include "ReorderingOptions.h" #include "ReorderingOptions.h"
#include "ContextParameters.h" #include "ContextParameters.h"
#include "InputOptions.h" #include "InputOptions.h"
#include "MBR_Options.h"
#include "LMBR_Options.h"
namespace Moses namespace Moses
{ {
@ -20,6 +22,11 @@ namespace Moses
ReorderingOptions reordering; ReorderingOptions reordering;
ContextParameters context; ContextParameters context;
InputOptions input; InputOptions input;
MBR_Options mbr;
LMBR_Options lmbr;
bool mira;
// StackOptions stack; // StackOptions stack;
// BeamSearchOptions beam; // BeamSearchOptions beam;
bool init(Parameter const& param); bool init(Parameter const& param);

View File

@ -0,0 +1,24 @@
// -*- mode: c++; indent-tabs-mode: nil; tab-width: 2 -*-
#include "LMBR_Options.h"
namespace Moses {

// Read the lattice-MBR decoding settings from the parsed command line /
// configuration. Every scalar option falls back to its documented
// default when absent; the theta vector is parsed only when supplied.
// Always succeeds (returns true).
bool
LMBR_Options::
init(Parameter const& param)
{
  param.SetParameter(enabled, "lminimum-bayes-risk", false);
  param.SetParameter(precision, "lmbr-p", 0.8f);
  param.SetParameter(ratio, "lmbr-r", 0.6f);
  param.SetParameter(map_weight, "lmbr-map-weight", 0.0f);
  param.SetParameter(pruning_factor, "lmbr-pruning-factor", size_t(30));
  param.SetParameter(use_lattice_hyp_set, "lattice-hypo-set", false);

  // lmbr-thetas has no scalar default: leave theta empty unless given.
  PARAM_VEC const* thetas = param.GetParam("lmbr-thetas");
  if (thetas) {
    theta = Scan<float>(*thetas);
  }
  return true;
}

} // namespace Moses

View File

@ -0,0 +1,25 @@
// -*- mode: c++; indent-tabs-mode: nil; tab-width: 2 -*-
#pragma once
#include <string>
#include <vector>
#include "moses/Parameter.h"
namespace Moses
{
// Options for lattice minimum Bayes risk decoding
// (see Tromble et al. 2008 and Kumar et al. 2009 for the algorithm).
struct
LMBR_Options
{
  bool enabled;
  bool use_lattice_hyp_set; //! to use nbest as hypothesis set during lattice MBR
  float precision; //! unigram precision theta - see Tromble et al 08 for more details
  float ratio; //! decaying factor for ngram thetas - see Tromble et al 08
  float map_weight; //! Weight given to the map solution. See Kumar et al 09
  size_t pruning_factor; //! average number of nodes per word wanted in pruned lattice
  std::vector<float> theta; //! theta(s) for lattice mbr calculation
  bool init(Parameter const& param);

  // Initialize the scalar members to the same defaults that init()
  // applies, so an instance is well-defined even before init() runs.
  // (The previous empty constructor left them uninitialized, which is
  // undefined behavior if any member is read first.)
  LMBR_Options()
    : enabled(false)
    , use_lattice_hyp_set(false)
    , precision(0.8f)
    , ratio(0.6f)
    , map_weight(0.0f)
    , pruning_factor(30)
  {}
};
}

View File

@ -0,0 +1,16 @@
// -*- mode: c++; indent-tabs-mode: nil; tab-width: 2 -*-
#include "MBR_Options.h"
namespace Moses {

// Populate the n-best MBR decoding settings from the parsed command
// line / configuration, using the documented defaults for any option
// that was not specified. Always succeeds (returns true).
bool
MBR_Options::
init(Parameter const& param)
{
  param.SetParameter(enabled, "minimum-bayes-risk", false);
  param.SetParameter(scale, "mbr-scale", 1.0f);
  param.SetParameter<size_t>(size, "mbr-size", 200);
  return true;
}

} // namespace Moses

View File

@ -0,0 +1,21 @@
// -*- mode: c++; indent-tabs-mode: nil; tab-width: 2 -*-
#pragma once
#include <string>
#include "moses/Parameter.h"
namespace Moses
{
// Options for minimum Bayes risk decoding over the n-best list.
struct
MBR_Options
{
  bool enabled;
  size_t size; //! number of translation candidates considered
  float scale; /*! scaling factor for computing marginal probability
                *  of candidate translation */
  bool init(Parameter const& param);

  // Match init()'s defaults so the members are never read
  // uninitialized (the previous empty constructor left enabled,
  // size and scale indeterminate).
  MBR_Options()
    : enabled(false)
    , size(200)
    , scale(1.0f)
  {}
};
}

View File

@ -31,6 +31,8 @@ namespace Moses
DEFAULT_MAX_PART_TRANS_OPT_SIZE); DEFAULT_MAX_PART_TRANS_OPT_SIZE);
param.SetParameter(consensus, "consensus-decoding", false);
// transformation to log of a few scores // transformation to log of a few scores
beam_width = TransformScore(beam_width); beam_width = TransformScore(beam_width);
trans_opt_threshold = TransformScore(trans_opt_threshold); trans_opt_threshold = TransformScore(trans_opt_threshold);

View File

@ -24,6 +24,8 @@ namespace Moses
int timeout; int timeout;
bool consensus; //! Use Consensus decoding (DeNero et al 2009)
// reordering options // reordering options
// bool reorderingConstraint; //! use additional reordering constraints // bool reorderingConstraint; //! use additional reordering constraints
// bool useEarlyDistortionCost; // bool useEarlyDistortionCost;