handle terminal alignments for hierarchical models

This commit is contained in:
Eva 2012-04-19 11:08:06 -07:00
parent 0986ca1b5d
commit 191a418aea
15 changed files with 143 additions and 48 deletions

View File

@ -26,7 +26,7 @@ namespace Moses
{
void AlignmentInfo::BuildNonTermIndexMap()
{
{
if (m_collection.empty()) {
return;
}
@ -40,9 +40,9 @@ void AlignmentInfo::BuildNonTermIndexMap()
m_nonTermIndexMap.resize(maxIndex+1, NOT_FOUND);
size_t i = 0;
for (p = begin(); p != end(); ++p) {
m_nonTermIndexMap[p->second] = i++;
//std::cerr << "nt point: " << p->second << " -> " << i << std::endl;
m_nonTermIndexMap[p->second] = i++;
}
}
bool compare_target(const std::pair<size_t,size_t> *a, const std::pair<size_t,size_t> *b) {

View File

@ -19,26 +19,28 @@
#pragma once
#include <iostream>
#include <ostream>
#include <set>
#include <vector>
#include <cstdlib>
namespace Moses
{
class AlignmentInfoCollection;
// Collection of non-terminal alignment pairs, ordered by source index.
// Collection of non-terminal/terminal alignment pairs, ordered by source index.
class AlignmentInfo
{
typedef std::set<std::pair<size_t,size_t> > CollType;
friend std::ostream& operator<<(std::ostream &, const AlignmentInfo &);
friend struct AlignmentInfoOrderer;
friend class AlignmentInfoCollection;
public:
typedef std::set<std::pair<size_t,size_t> > CollType;
typedef std::vector<size_t> NonTermIndexMap;
typedef std::vector<size_t> TermIndexMap;
typedef CollType::const_iterator const_iterator;
const_iterator begin() const { return m_collection.begin(); }
@ -50,7 +52,17 @@ class AlignmentInfo
const NonTermIndexMap &GetNonTermIndexMap() const {
return m_nonTermIndexMap;
}
// only used for hierarchical models, contains terminal alignments
const CollType &GetTerminalAlignments() const {
return m_terminalCollection;
}
// for phrase-based models, this contains all alignments, for hierarchical models only the NT alignments
const CollType &GetAlignments() const {
return m_collection;
}
std::vector< const std::pair<size_t,size_t>* > GetSortedAlignments() const;
private:
@ -60,10 +72,29 @@ class AlignmentInfo
{
BuildNonTermIndexMap();
}
// use this for hierarchical models
explicit AlignmentInfo(const std::set<std::pair<size_t,size_t> > &pairs, int* indicator)
{
// split alignment set in terminals and non-terminals
std::set<std::pair<size_t,size_t> > terminalSet;
std::set<std::pair<size_t,size_t> > nonTerminalSet;
std::set<std::pair<size_t,size_t> >::iterator iter;
for (iter = pairs.begin(); iter != pairs.end(); ++iter) {
if (*indicator == 1) nonTerminalSet.insert(*iter);
else terminalSet.insert(*iter);
indicator++;
}
m_collection = nonTerminalSet;
m_terminalCollection = terminalSet;
BuildNonTermIndexMap();
}
void BuildNonTermIndexMap();
CollType m_collection;
CollType m_terminalCollection;
NonTermIndexMap m_nonTermIndexMap;
};
@ -72,6 +103,8 @@ class AlignmentInfo
struct AlignmentInfoOrderer
{
bool operator()(const AlignmentInfo &a, const AlignmentInfo &b) const {
if (a.m_collection == b.m_collection)
return a.m_terminalCollection < b.m_terminalCollection;
return a.m_collection < b.m_collection;
}
};

View File

@ -43,4 +43,12 @@ const AlignmentInfo *AlignmentInfoCollection::Add(
return &(*ret.first);
}
const AlignmentInfo *AlignmentInfoCollection::Add(
const std::set<std::pair<size_t,size_t> > &pairs, int* indicator)
{
std::pair<AlignmentInfoSet::iterator, bool> ret =
m_collection.insert(AlignmentInfo(pairs, indicator));
return &(*ret.first);
}
}

View File

@ -37,6 +37,7 @@ class AlignmentInfoCollection
// contains such an object then returns a pointer to it; otherwise a new
// one is inserted.
const AlignmentInfo *Add(const std::set<std::pair<size_t,size_t> > &);
const AlignmentInfo *Add(const std::set<std::pair<size_t,size_t> > &, int* indicator);
// Returns a pointer to an empty AlignmentInfo object.
const AlignmentInfo &GetEmptyAlignmentInfo() const;

View File

@ -72,7 +72,7 @@ bool RuleTableLoaderCompact::Load(const std::vector<FactorType> &input,
// Load alignments.
std::vector<const AlignmentInfo *> alignmentSets;
LoadAlignmentSection(reader, alignmentSets);
LoadAlignmentSection(reader, alignmentSets, sourcePhrases);
// Load rules.
if (!LoadRuleSection(reader, vocab, sourcePhrases, targetPhrases,
@ -136,7 +136,7 @@ void RuleTableLoaderCompact::LoadPhraseSection(
}
void RuleTableLoaderCompact::LoadAlignmentSection(
LineReader &reader, std::vector<const AlignmentInfo *> &alignmentSets)
LineReader &reader, std::vector<const AlignmentInfo *> &alignmentSets, std::vector<Phrase> &sourcePhrases)
{
// Read alignment set count.
reader.ReadLine();
@ -153,13 +153,16 @@ void RuleTableLoaderCompact::LoadAlignmentSection(
reader.ReadLine();
Tokenize(tokens, reader.m_line);
std::vector<std::string>::const_iterator p;
int indicator[tokens.size()];
size_t index = 0;
for (p = tokens.begin(); p != tokens.end(); ++p) {
points.clear();
Tokenize<size_t>(points, *p, "-");
std::pair<size_t, size_t> alignmentPair(points[0], points[1]);
alignmentInfo.insert(alignmentPair);
indicator[index++] = sourcePhrases[i].GetWord(points[0]).IsNonTerminal() ? 1: 0;
}
alignmentSets[i] = AlignmentInfoCollection::Instance().Add(alignmentInfo);
alignmentSets[i] = AlignmentInfoCollection::Instance().Add(alignmentInfo, indicator);
}
}

View File

@ -70,7 +70,8 @@ class RuleTableLoaderCompact : public RuleTableLoader
std::vector<size_t> &);
void LoadAlignmentSection(LineReader &,
std::vector<const AlignmentInfo *> &);
std::vector<const AlignmentInfo *> &,
std::vector<Phrase> &);
bool LoadRuleSection(LineReader &,
const std::vector<Word> &,

View File

@ -220,7 +220,7 @@ bool RuleTableLoaderStandard::Load(FormatType format
targetPhrase->SetSourcePhrase(sourcePhrase);
// rest of target phrase
targetPhrase->SetAlignmentInfo(alignString);
targetPhrase->SetAlignmentInfo(alignString, sourcePhrase);
targetPhrase->SetTargetLHS(targetLHS);
targetPhrase->SetRuleCount(ruleCountString, scoreVector);
//targetPhrase->SetDebugOutput(string("New Format pt ") + line);
@ -242,7 +242,6 @@ bool RuleTableLoaderStandard::Load(FormatType format
else
{ // do nothing
}
}
// sort and prune each target phrase collection

View File

@ -34,18 +34,23 @@ void SourceWordDeletionFeature::Evaluate(const Hypothesis& cur_hypo,
ScoreComponentCollection* accumulator) const
{
TargetPhrase targetPhrase = cur_hypo.GetCurrTargetPhrase();
ComputeFeatures(targetPhrase, accumulator);
const AlignmentInfo &alignmentInfo = targetPhrase.GetAlignmentInfo();
const AlignmentInfo::CollType &alignment = alignmentInfo.GetAlignments();
ComputeFeatures(targetPhrase, accumulator, alignment);
}
void SourceWordDeletionFeature::EvaluateChart(const ChartHypothesis& cur_hypo, int featureId,
ScoreComponentCollection* accumulator) const
{
TargetPhrase targetPhrase = cur_hypo.GetCurrTargetPhrase();
ComputeFeatures(targetPhrase, accumulator);
const AlignmentInfo &alignmentInfo = targetPhrase.GetAlignmentInfo();
const AlignmentInfo::CollType &alignment = alignmentInfo.GetTerminalAlignments();
ComputeFeatures(targetPhrase, accumulator, alignment);
}
void SourceWordDeletionFeature::ComputeFeatures(const TargetPhrase& targetPhrase,
ScoreComponentCollection* accumulator) const
ScoreComponentCollection* accumulator,
const AlignmentInfo::CollType &alignment) const
{
// handle special case: unknown words (they have no word alignment)
size_t targetLength = targetPhrase.GetSize();
@ -58,24 +63,26 @@ void SourceWordDeletionFeature::ComputeFeatures(const TargetPhrase& targetPhrase
}
// flag aligned words
const AlignmentInfo &alignment = targetPhrase.GetAlignmentInfo();
bool aligned[16];
CHECK(sourceLength < 16);
for(size_t i=0; i<sourceLength; i++)
aligned[i] = false;
for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++)
for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++)
aligned[ alignmentPoint->first ] = true;
// process unaligned source words
for(size_t i=0; i<sourceLength; i++) {
if (!aligned[i]) {
const string &word = targetPhrase.GetSourcePhrase().GetWord(i).GetFactor(m_factorType)->GetString();
if (word != "<s>" && word != "</s>") {
if (!m_unrestricted && m_vocab.find( word ) == m_vocab.end()) {
accumulator->PlusEquals(this,"OTHER",1);
}
else {
accumulator->PlusEquals(this,word,1);
Word w = targetPhrase.GetSourcePhrase().GetWord(i);
if (!w.IsNonTerminal()) {
const string &word = w.GetFactor(m_factorType)->GetString();
if (word != "<s>" && word != "</s>") {
if (!m_unrestricted && m_vocab.find( word ) == m_vocab.end()) {
accumulator->PlusEquals(this,"OTHER",1);
}
else {
accumulator->PlusEquals(this,word,1);
}
}
}
}

View File

@ -6,6 +6,7 @@
#include "FeatureFunction.h"
#include "FactorCollection.h"
#include "AlignmentInfo.h"
namespace Moses
{
@ -35,8 +36,9 @@ public:
int featureId,
ScoreComponentCollection* accumulator) const;
void ComputeFeatures(const TargetPhrase& targetPhrase,
ScoreComponentCollection* accumulator) const;
void ComputeFeatures(const TargetPhrase& targetPhrase,
ScoreComponentCollection* accumulator,
const AlignmentInfo::CollType &alignment) const;
// basic properties
std::string GetScoreProducerWeightShortName(unsigned) const { return "swd"; }

View File

@ -318,11 +318,38 @@ void TargetPhrase::SetAlignmentInfo(const StringPiece &alignString)
SetAlignmentInfo(alignmentInfo);
}
void TargetPhrase::SetAlignmentInfo(const StringPiece &alignString, Phrase &sourcePhrase)
{
std::vector<std::string> alignPoints;
boost::split(alignPoints, alignString, boost::is_any_of("\t "));
int indicator[alignPoints.size()];
int index = 0;
set<pair<size_t,size_t> > alignmentInfo;
for (util::TokenIter<util::AnyCharacter, true> token(alignString, util::AnyCharacter(" \t")); token; ++token) {
util::TokenIter<util::AnyCharacter, false> dash(*token, util::AnyCharacter("-"));
MosesShouldUseExceptions(dash);
size_t sourcePos = boost::lexical_cast<size_t>(*dash++);
MosesShouldUseExceptions(dash);
size_t targetPos = boost::lexical_cast<size_t>(*dash++);
MosesShouldUseExceptions(!dash);
alignmentInfo.insert(pair<size_t,size_t>(sourcePos, targetPos));
indicator[index++] = sourcePhrase.GetWord(sourcePos).IsNonTerminal() ? 1: 0;
}
SetAlignmentInfo(alignmentInfo, indicator);
}
void TargetPhrase::SetAlignmentInfo(const std::set<std::pair<size_t,size_t> > &alignmentInfo)
{
m_alignmentInfo = AlignmentInfoCollection::Instance().Add(alignmentInfo);
}
void TargetPhrase::SetAlignmentInfo(const std::set<std::pair<size_t,size_t> > &alignmentInfo, int* indicator)
{
m_alignmentInfo = AlignmentInfoCollection::Instance().Add(alignmentInfo, indicator);
}
TO_STRING_BODY(TargetPhrase);
@ -350,7 +377,7 @@ void TargetPhrase::SetRuleCount(const StringPiece &ruleCountString, std::vector<
}
else {
if (scoreVector.size() >= 1 ) p_f_given_e = scoreVector[0];
std::cerr << "Warning: possibly wrong format of phrase translation scores, number of scores: " << scoreVector.size() << endl;
// std::cerr << "Warning: possibly wrong format of phrase translation scores, number of scores: " << scoreVector.size() << endl;
}
targetCount = Scan<float>(tokens[0]);

View File

@ -152,7 +152,9 @@ public:
{ return m_lhsTarget; }
void SetAlignmentInfo(const StringPiece &alignString);
void SetAlignmentInfo(const std::set<std::pair<size_t,size_t> > &alignmentInfo);
void SetAlignmentInfo(const StringPiece &alignString, Phrase &sourcePhrase);
void SetAlignmentInfo(const std::set<std::pair<size_t,size_t> > &alignmentInfo);
void SetAlignmentInfo(const std::set<std::pair<size_t,size_t> > &alignmentInfo, int* indicator);
void SetAlignmentInfo(const AlignmentInfo *alignmentInfo) {
m_alignmentInfo = alignmentInfo;
}

View File

@ -34,19 +34,24 @@ void TargetWordInsertionFeature::Evaluate(const Hypothesis& cur_hypo,
ScoreComponentCollection* accumulator) const
{
const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
ComputeFeatures(targetPhrase, accumulator);
const AlignmentInfo &alignmentInfo = targetPhrase.GetAlignmentInfo();
const AlignmentInfo::CollType &alignment = alignmentInfo.GetAlignments();
ComputeFeatures(targetPhrase, accumulator, alignment);
}
void TargetWordInsertionFeature::EvaluateChart(const ChartHypothesis& cur_hypo,
int featureID,
ScoreComponentCollection* accumulator) const
int featureID,
ScoreComponentCollection* accumulator) const
{
const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
ComputeFeatures(targetPhrase, accumulator);
const AlignmentInfo &alignmentInfo = targetPhrase.GetAlignmentInfo();
const AlignmentInfo::CollType &alignment = alignmentInfo.GetTerminalAlignments();
ComputeFeatures(targetPhrase, accumulator, alignment);
}
void TargetWordInsertionFeature::ComputeFeatures(const TargetPhrase& targetPhrase,
ScoreComponentCollection* accumulator) const
ScoreComponentCollection* accumulator,
const AlignmentInfo::CollType &alignment) const
{
// handle special case: unknown words (they have no word alignment)
size_t targetLength = targetPhrase.GetSize();
@ -59,7 +64,6 @@ void TargetWordInsertionFeature::ComputeFeatures(const TargetPhrase& targetPhras
}
// flag aligned words
const AlignmentInfo &alignment = targetPhrase.GetAlignmentInfo();
bool aligned[16];
CHECK(targetLength < 16);
for(size_t i=0; i<targetLength; i++) {
@ -72,14 +76,17 @@ void TargetWordInsertionFeature::ComputeFeatures(const TargetPhrase& targetPhras
// process unaligned target words
for(size_t i=0; i<targetLength; i++) {
if (!aligned[i]) {
const string &word = targetPhrase.GetWord(i).GetFactor(m_factorType)->GetString();
if (word != "<s>" && word != "</s>") {
if (!m_unrestricted && m_vocab.find( word ) == m_vocab.end()) {
Word w = targetPhrase.GetWord(i);
if (!w.IsNonTerminal()) {
const string &word = w.GetFactor(m_factorType)->GetString();
if (word != "<s>" && word != "</s>") {
if (!m_unrestricted && m_vocab.find( word ) == m_vocab.end()) {
accumulator->PlusEquals(this,"OTHER",1);
}
else {
}
else {
accumulator->PlusEquals(this,word,1);
}
}
}
}
}
}

View File

@ -6,6 +6,7 @@
#include "FeatureFunction.h"
#include "FactorCollection.h"
#include "AlignmentInfo.h"
namespace Moses
{
@ -23,7 +24,9 @@ public:
StatelessFeatureFunction("twi", ScoreProducer::unlimited),
m_factorType(factorType),
m_unrestricted(true)
{}
{
std::cerr << "Initializing target word insertion feature.." << std::endl;
}
bool Load(const std::string &filePath);
void Evaluate(const Hypothesis& cur_hypo,
@ -35,7 +38,8 @@ public:
ScoreComponentCollection* accumulator) const;
void ComputeFeatures(const TargetPhrase& targetPhrase,
ScoreComponentCollection* accumulator) const;
ScoreComponentCollection* accumulator,
const AlignmentInfo::CollType &alignment) const;
// basic properties
std::string GetScoreProducerWeightShortName(unsigned) const { return "twi"; }

View File

@ -192,9 +192,10 @@ void WordTranslationFeature::Evaluate(const Hypothesis& cur_hypo, ScoreComponent
void WordTranslationFeature::EvaluateChart(const ChartHypothesis& cur_hypo, int featureID,
ScoreComponentCollection* accumulator) const
{
const Sentence& input = *(m_local->input);
//const Sentence& input = *(m_local->input);
const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase();
const AlignmentInfo &alignment = targetPhrase.GetAlignmentInfo();
const AlignmentInfo &alignmentInfo = targetPhrase.GetAlignmentInfo();
const AlignmentInfo::CollType &alignment = alignmentInfo.GetTerminalAlignments();
// process aligned words
for (AlignmentInfo::const_iterator alignmentPoint = alignment.begin(); alignmentPoint != alignment.end(); alignmentPoint++) {

View File

@ -60,7 +60,7 @@ public:
m_sparseProducerWeight(1),
m_ignorePunctuation(ignorePunctuation)
{
std::cerr << "Creating word translation feature.. ";
std::cerr << "Initializing word translation feature.. ";
if (m_simple == 1) std::cerr << "using simple word translations.. ";
if (m_sourceContext == 1) std::cerr << "using source context.. ";
if (m_targetContext == 1) std::cerr << "using target context.. ";