added support tp specify translation options via xml for chart decoder

This commit is contained in:
phikoehn 2012-03-28 04:29:24 +01:00
parent 9e5e502687
commit 292c75cb1a
4 changed files with 131 additions and 10 deletions

View File

@ -29,6 +29,7 @@
#include "ChartTrellisPathList.h"
#include "StaticData.h"
#include "DecodeStep.h"
#include "TreeInput.h"
using namespace std;
using namespace Moses;
@ -79,6 +80,8 @@ void ChartManager::ProcessSentence()
VERBOSE(2,"Decoding: " << endl);
//ChartHypothesis::ResetHypoCount();
AddXmlChartOptions();
// MAIN LOOP
size_t size = m_source.GetSize();
for (size_t width = 1; width <= size; ++width) {
@ -122,6 +125,28 @@ void ChartManager::ProcessSentence()
}
}
void ChartManager::AddXmlChartOptions() {
TreeInput const &source = dynamic_cast<TreeInput const&>(m_source);
const std::vector <ChartTranslationOption*> xmlChartOptionsList = source.GetXmlChartTranslationOptions();
IFVERBOSE(2) { cerr << "AddXmlChartOptions " << xmlChartOptionsList.size() << endl; }
if (xmlChartOptionsList.size() == 0) return;
for(std::vector<ChartTranslationOption*>::const_iterator i = xmlChartOptionsList.begin();
i != xmlChartOptionsList.end(); ++i) {
ChartTranslationOption* opt = *i;
Moses::Scores wordPenaltyScore(1, -0.434294482); // TODO what is this number?
opt->GetTargetPhraseCollection().GetCollection()[0]->SetScore((ScoreProducer*)m_system->GetWordPenaltyProducer(), wordPenaltyScore);
const WordsRange &range = opt->GetSourceWordsRange();
RuleCubeItem* item = new RuleCubeItem( *opt, m_hypoStackColl );
ChartHypothesis* hypo = new ChartHypothesis(*opt, *item, *this);
hypo->CalcScore();
ChartCell &cell = m_hypoStackColl.Get(range);
cell.AddHypothesis(hypo);
}
}
const ChartHypothesis *ChartManager::GetBestHypothesis() const
{
size_t size = m_source.GetSize();

View File

@ -65,6 +65,7 @@ public:
ChartManager(InputType const& source, const TranslationSystem* system);
~ChartManager();
void ProcessSentence();
void AddXmlChartOptions();
const ChartHypothesis *GetBestHypothesis() const;
void CalcNBest(size_t count, ChartTrellisPathList &ret,bool onlyDistinct=0) const;

View File

@ -20,7 +20,7 @@ namespace Moses
* \param reorderingConstraint reordering constraint zones specified by xml
* \param walls reordering constraint walls specified by xml
*/
bool TreeInput::ProcessAndStripXMLTags(string &line, std::vector<XMLParseOutput> &sourceLabels)
bool TreeInput::ProcessAndStripXMLTags(string &line, std::vector<XMLParseOutput> &sourceLabels, std::vector<XmlOption*> &xmlOptions)
{
//parse XML markup in translation line
@ -41,6 +41,10 @@ bool TreeInput::ProcessAndStripXMLTags(string &line, std::vector<XMLParseOutput>
string cleanLine; // return string (text without xml)
size_t wordPos = 0; // position in sentence (in terms of number of words)
// keep this handy for later
const vector<FactorType> &outputFactorOrder = StaticData::Instance().GetOutputFactorOrder();
const string &factorDelimiter = StaticData::Instance().GetFactorDelimiter();
// loop through the tokens
for (size_t xmlTokenPos = 0 ; xmlTokenPos < xmlTokens.size() ; xmlTokenPos++) {
// not a xml tag, but regular text (may contain many words)
@ -145,14 +149,63 @@ bool TreeInput::ProcessAndStripXMLTags(string &line, std::vector<XMLParseOutput>
return false;
}
WordsRange range(startPos,endPos-1);
// specified translations -> vector of phrases
// multiple translations may be specified, separated by "||"
vector<string> altTexts = TokenizeMultiCharSeparator(ParseXmlTagAttribute(tagContent,"label"), "||");
CHECK(altTexts.size() == 1);
// may be either a input span label ("label"), or a specified output translation "translation"
string label = ParseXmlTagAttribute(tagContent,"label");
string translation = ParseXmlTagAttribute(tagContent,"translation");
XMLParseOutput item(altTexts[0], range);
sourceLabels.push_back(item);
// specified label
if (translation.length() == 0 && label.length() > 0) {
WordsRange range(startPos,endPos-1); // really?
XMLParseOutput item(label, range);
sourceLabels.push_back(item);
}
// specified translations -> vector of phrases, separated by "||"
if (translation.length() > 0 && StaticData::Instance().GetXmlInputType() != XmlIgnore) {
vector<string> altTexts = TokenizeMultiCharSeparator(translation, "||");
vector<string> altLabel = TokenizeMultiCharSeparator(label, "||");
vector<string> altProbs = TokenizeMultiCharSeparator(ParseXmlTagAttribute(tagContent,"prob"), "||");
//TRACE_ERR("number of translations: " << altTexts.size() << endl);
for (size_t i=0; i<altTexts.size(); ++i) {
// set target phrase
TargetPhrase targetPhrase(Output);
targetPhrase.CreateFromString(outputFactorOrder,altTexts[i],factorDelimiter);
// set constituent label
string targetLHSstr;
if (altLabel.size() > i && altLabel[i].size() > 0) {
targetLHSstr = altLabel[i];
}
else {
const UnknownLHSList &lhsList = StaticData::Instance().GetUnknownLHS();
UnknownLHSList::const_iterator iterLHS = lhsList.begin();
targetLHSstr = iterLHS->first;
}
Word targetLHS(true);
targetLHS.CreateFromString(Output, outputFactorOrder, targetLHSstr, true);
CHECK(targetLHS.GetFactor(0) != NULL);
targetPhrase.SetTargetLHS(targetLHS);
// get probability
float probValue = 1;
if (altProbs.size() > i && altProbs[i].size() > 0) {
probValue = Scan<float>(altProbs[i]);
}
// convert from prob to log-prob
float scoreValue = FloorScore(TransformScore(probValue));
targetPhrase.SetScore(scoreValue);
// set span and create XmlOption
WordsRange range(startPos+1,endPos);
XmlOption *option = new XmlOption(range,targetPhrase);
CHECK(option);
xmlOptions.push_back(option);
VERBOSE(2,"xml translation = [" << range << "] " << targetLHSstr << " -> " << altTexts[i] << " prob: " << probValue << endl);
}
altTexts.clear();
altProbs.clear();
}
}
}
}
@ -179,7 +232,8 @@ int TreeInput::Read(std::istream& in,const std::vector<FactorType>& factorOrder)
//line = Trim(line);
std::vector<XMLParseOutput> sourceLabels;
ProcessAndStripXMLTags(line, sourceLabels);
std::vector<XmlOption*> xmlOptionsList;
ProcessAndStripXMLTags(line, sourceLabels, xmlOptionsList);
// do words 1st - hack
stringstream strme;
@ -211,6 +265,42 @@ int TreeInput::Read(std::istream& in,const std::vector<FactorType>& factorOrder)
}
}
// XML Options
//only fill the vector if we are parsing XML
if (staticData.GetXmlInputType() != XmlPassThrough ) {
//TODO: needed to handle exclusive
//for (size_t i=0; i<GetSize(); i++) {
// m_xmlCoverageMap.push_back(false);
//}
//iterXMLOpts will be empty for XmlIgnore
//look at each column
for(std::vector<XmlOption*>::const_iterator iterXmlOpts = xmlOptionsList.begin();
iterXmlOpts != xmlOptionsList.end(); iterXmlOpts++) {
const XmlOption *xmlOption = *iterXmlOpts;
TargetPhrase *targetPhrase = new TargetPhrase(xmlOption->targetPhrase);
*targetPhrase = xmlOption->targetPhrase; // copy everything
WordsRange *range = new WordsRange(xmlOption->range);
const StackVec emptyStackVec; // hmmm... maybe dangerous, but it is never consulted
TargetPhraseCollection *tpc = new TargetPhraseCollection;
tpc->Add(targetPhrase);
ChartTranslationOption *transOpt = new ChartTranslationOption(*tpc, emptyStackVec, *range, 0.0f);
m_xmlChartOptionsList.push_back(transOpt);
//TODO: needed to handle exclusive
//for(size_t j=transOpt->GetSourceWordsRange().GetStartPos(); j<=transOpt->GetSourceWordsRange().GetEndPos(); j++) {
// m_xmlCoverageMap[j]=true;
//}
delete xmlOption;
}
}
return 1;
}

View File

@ -4,6 +4,7 @@
#include <vector>
#include "Sentence.h"
#include "ChartTranslationOption.h"
namespace Moses
{
@ -25,6 +26,7 @@ class TreeInput : public Sentence
protected:
std::vector<std::vector<NonTerminalSet> > m_sourceChart;
std::vector <ChartTranslationOption*> m_xmlChartOptionsList;
void AddChartLabel(size_t startPos, size_t endPos, const std::string &label
,const std::vector<FactorType>& factorOrder);
@ -34,7 +36,7 @@ protected:
return m_sourceChart[startPos][endPos - startPos];
}
bool ProcessAndStripXMLTags(std::string &line, std::vector<XMLParseOutput> &sourceLabels);
bool ProcessAndStripXMLTags(std::string &line, std::vector<XMLParseOutput> &sourceLabels, std::vector<XmlOption*> &res);
public:
TreeInput()
@ -57,6 +59,9 @@ public:
return m_sourceChart[startPos][endPos - startPos];
}
std::vector <ChartTranslationOption*> GetXmlChartTranslationOptions() const {
return m_xmlChartOptionsList;
};
};
}