Modify mosesserver to allow weight updates, change the BleuFeature compare function, add debug code to Hypothesis.cpp and FeatureVector.cpp, fix OutputSearchNode in Manager.cpp, and add a client script for translation with weight updates

This commit is contained in:
Eva 2012-05-06 16:04:55 -07:00
parent ea6212432a
commit f147031576
6 changed files with 234 additions and 44 deletions

View File

@ -15,6 +15,7 @@
#ifdef LM_ORLM
# include "LanguageModelORLM.h"
#endif
#include <boost/algorithm/string.hpp>
using namespace Moses;
using namespace std;
@ -149,10 +150,8 @@ public:
this->_help = "Does translation";
}
void
execute(xmlrpc_c::paramList const& paramList,
void execute(xmlrpc_c::paramList const& paramList,
xmlrpc_c::value * const retvalP) {
const params_t params = paramList.getStruct(0);
paramList.verifyEnd(1);
params_t::const_iterator si = params.find("text");
@ -164,7 +163,7 @@ public:
const string source(
(xmlrpc_c::value_string(si->second)));
cerr << "Input: " << source << endl;
cerr << "Input: " << source;
si = params.find("align");
bool addAlignInfo = (si != params.end());
si = params.find("sg");
@ -198,7 +197,7 @@ public:
map<string, xmlrpc_c::value> retData;
pair<string, xmlrpc_c::value>
text("text", xmlrpc_c::value_string(out.str()));
cerr << "Output: " << out.str() << endl;
cerr << "Output: " << out.str() << endl << endl;
if (addAlignInfo) {
retData.insert(pair<string, xmlrpc_c::value>("align", xmlrpc_c::value_array(alignInfo)));
}
@ -262,8 +261,12 @@ public:
}
searchGraphXmlNode["cover-start"] = xmlrpc_c::value_int(hypo->GetCurrSourceWordsRange().GetStartPos());
searchGraphXmlNode["cover-end"] = xmlrpc_c::value_int(hypo->GetCurrSourceWordsRange().GetEndPos());
searchGraphXmlNode["out"] =
xmlrpc_c::value_string(hypo->GetCurrTargetPhrase().GetStringRep(StaticData::Instance().GetOutputFactorOrder()));
stringstream tmp;
tmp << hypo->GetSourcePhraseStringRep() << "|" << hypo->GetCurrTargetPhrase().GetStringRep(StaticData::Instance().GetOutputFactorOrder());
searchGraphXmlNode["out"] = xmlrpc_c::value_string(tmp.str());
tmp.str("");
tmp << hypo->GetScoreBreakdown();
searchGraphXmlNode["scores"] = xmlrpc_c::value_string(tmp.str());
}
searchGraphXml.push_back(xmlrpc_c::value_struct(searchGraphXmlNode));
}
@ -290,10 +293,11 @@ public:
toptXml["start"] = xmlrpc_c::value_int(startPos);
toptXml["end"] = xmlrpc_c::value_int(endPos);
vector<xmlrpc_c::value> scoresXml;
cerr << "Warning: not adding scores to translation options.." << endl;
ScoreComponentCollection scores = topt->GetScoreBreakdown();
for (size_t j = 0; j < scores.size(); ++j) {
scoresXml.push_back(xmlrpc_c::value_double(scores[j]));
}
// for (size_t j = 0; j < scores.size(); ++j) {
// scoresXml.push_back(xmlrpc_c::value_double(scores[j]));
// }
toptXml["scores"] = xmlrpc_c::value_array(scoresXml);
toptsXml.push_back(xmlrpc_c::value_struct(toptXml));
}
@ -301,11 +305,122 @@ public:
}
retData.insert(pair<string, xmlrpc_c::value>("topt", xmlrpc_c::value_array(toptsXml)));
}
};
// Names of the core (dense) features, in the order in which a client must list
// their values in the comma-separated "core-weights" parameter:
// WeightUpdater::loadCoreWeight pairs the i-th received weight with ffNames[i]
// and assigns it to every feature function whose score producer description
// is a prefix of that name.
// The commented-out table below is a longer variant of the same list
// (six lexical reordering and five phrase model entries); only the short
// table is compiled in.
/*const char* ffNames[] = { "Distortion", "WordPenalty", "!UnknownWordPenalty 1", "LexicalReordering_wbe-msd-bidirectional-fe-allff_1",
"LexicalReordering_wbe-msd-bidirectional-fe-allff_2", "LexicalReordering_wbe-msd-bidirectional-fe-allff_3",
"LexicalReordering_wbe-msd-bidirectional-fe-allff_4", "LexicalReordering_wbe-msd-bidirectional-fe-allff_5",
"LexicalReordering_wbe-msd-bidirectional-fe-allff_6", "LM", "PhraseModel_1", "PhraseModel_2", "PhraseModel_3",
"PhraseModel_4", "PhraseModel_5" };*/
const char* ffNames[] = { "Distortion", "WordPenalty", "!UnknownWordPenalty 1", "LM", "PhraseModel_1", "PhraseModel_2" };
/**
 * XML-RPC method that replaces the decoder's feature weights.
 *
 * The single struct parameter may contain:
 *   - "core-weights"   (required): comma-separated values, paired positionally
 *                      with the entries of ffNames.
 *   - "sparse-weights" (optional): entries of the form name=value separated by
 *                      spaces or tabs. The name itself may contain '='; the
 *                      last '=' of an entry separates name from value.
 *
 * All weights are collected into a local ScoreComponentCollection first and
 * handed to StaticData::SetAllWeights only once the whole request has been
 * parsed, so a malformed request (reported as an xmlrpc_c::fault) leaves the
 * current weights untouched. On success the method returns the string
 * "Weights updated!".
 *
 * NOTE(review): this writes to the global StaticData instance; whether that is
 * safe while translation requests are being served concurrently is not
 * visible from here -- confirm.
 */
class WeightUpdater: public xmlrpc_c::method
{
public:
  WeightUpdater() {
    // signature and help strings are documentation -- the client
    // can query this information with a system.methodSignature and
    // system.methodHelp RPC.
    this->_signature = "S:S";
    this->_help = "Updates Moses weights";
  }

  /**
   * Parses the request, builds the new weight collection and installs it.
   * Throws xmlrpc_c::fault (CODE_PARSE) if "core-weights" is missing, if it
   * lists more values than ffNames has entries, or if a sparse entry has no '='.
   */
  void execute(xmlrpc_c::paramList const& paramList,
               xmlrpc_c::value * const retvalP) {
    const params_t params = paramList.getStruct(0);
    paramList.verifyEnd(1);

    ScoreComponentCollection updatedWeights;

    // --- core (dense) weights -------------------------------------------
    params_t::const_iterator si = params.find("core-weights");
    if (si == params.end()) {
      throw xmlrpc_c::fault(
        "Missing core weights",
        xmlrpc_c::fault::CODE_PARSE);
    }
    const string coreWeights(
      (xmlrpc_c::value_string(si->second)));
    VERBOSE(1, "core weights: " << coreWeights << endl);

    StaticData &staticData = StaticData::InstanceNonConst();
    const vector<const ScoreProducer*> featureFunctions = staticData.GetTranslationSystem(TranslationSystem::DEFAULT).GetFeatureFunctions();
    vector<string> coreWeightVector;
    boost::split(coreWeightVector, coreWeights, boost::is_any_of(","));
    if (!loadCoreWeight(updatedWeights, coreWeightVector, featureFunctions)) {
      throw xmlrpc_c::fault(
        "Too many core weights",
        xmlrpc_c::fault::CODE_PARSE);
    }

    // --- sparse weights (optional) --------------------------------------
    si = params.find("sparse-weights");
    if (si != params.end()) {
      const string sparseWeights(
        (xmlrpc_c::value_string(si->second)));
      VERBOSE(1, "sparse weights: " << sparseWeights << endl);

      vector<string> sparseWeightVector;
      boost::split(sparseWeightVector, sparseWeights, boost::is_any_of("\t "));
      for (size_t i = 0; i < sparseWeightVector.size(); ++i) {
        const string &entry = sparseWeightVector[i];
        // Consecutive separators yield empty tokens; they carry no weight.
        if (entry.empty()) {
          continue;
        }
        // Feature names may contain '=' themselves, so the value is whatever
        // follows the LAST '=' and the name is everything before it.
        const string::size_type separator = entry.rfind('=');
        if (separator == string::npos) {
          throw xmlrpc_c::fault(
            "Malformed sparse weight (expected name=value): " + entry,
            xmlrpc_c::fault::CODE_PARSE);
        }
        const string name(entry.substr(0, separator));
        const float value = Scan<float>(entry.substr(separator + 1));
        VERBOSE(1, "Setting sparse weight " << name << " to value " << value << "." << endl);
        updatedWeights.Assign(name, value);
      }
    }

    staticData.SetAllWeights(updatedWeights);
    cerr << "\nUpdated weights: " << staticData.GetAllWeights() << endl;
    *retvalP = xmlrpc_c::value_string("Weights updated!");
  }

  /**
   * Assigns the core weights to their feature functions.
   *
   * coreWeightVector[i] belongs to the feature named ffNames[i]. A feature
   * function is selected when its score producer description is a prefix of
   * that name. Producers with a single score component get the weight
   * directly; for producers with several components the weights are buffered
   * until one weight per component has been seen and are then assigned
   * together.
   *
   * @return false (without touching `weights`) if more weights were supplied
   *         than ffNames has entries, true otherwise.
   */
  bool loadCoreWeight(ScoreComponentCollection &weights, vector<string> &coreWeightVector, const vector<const ScoreProducer*> &featureFunctions) {
    // Guard the positional lookup into ffNames below.
    const size_t numNames = sizeof(ffNames) / sizeof(ffNames[0]);
    if (coreWeightVector.size() > numNames) {
      return false;
    }

    vector< float > store_weights;
    for (size_t i = 0; i < coreWeightVector.size(); ++i) {
      const string name(ffNames[i]);
      const float weight = Scan<float>(coreWeightVector[i]);
      VERBOSE(1, "loading core weight " << name << endl);

      for (size_t f = 0; f < featureFunctions.size(); ++f) {
        const ScoreProducer *producer = featureFunctions[f];
        const std::string prefix = producer->GetScoreProducerDescription();
        if (name.substr( 0, prefix.length() ).compare( prefix ) != 0) {
          continue;
        }

        const size_t numberScoreComponents = producer->GetNumScoreComponents();
        if (numberScoreComponents == 1) {
          VERBOSE(1, "assign 1 weight for " << producer->GetScoreProducerDescription());
          VERBOSE(1, " (" << weight << ")" << endl << endl);
          weights.Assign(producer, weight);
        }
        else {
          store_weights.push_back(weight);
          if (store_weights.size() == numberScoreComponents) {
            VERBOSE(1, "assign " << store_weights.size() << " weights for " << producer->GetScoreProducerDescription() << " (");
            for (size_t j = 0; j < store_weights.size(); ++j) {
              VERBOSE(1, store_weights[j] << " ");
            }
            VERBOSE(1, ")" << endl << endl);
            weights.Assign(producer, store_weights);
            store_weights.clear();
          }
        }
      }
    }
    return true;
  }
};
/**
 * Returns a freshly allocated, NUL-terminated C string holding a copy of s.
 * The buffer is obtained with new[]; the function itself never frees it.
 **/
static char* strToChar(const string& s) {
  const size_t bytes = s.size() + 1;  // characters plus terminating NUL
  char* buffer = new char[bytes];
  memcpy(buffer, s.c_str(), bytes);
  return buffer;
}
int main(int argc, char** argv)
{
@ -344,6 +459,16 @@ int main(int argc, char** argv)
}
}
cerr << "Switching off translation option cache.." << endl;
mosesargv[mosesargc] = strToChar("-use-persistent-cache");
++mosesargc;
mosesargv[mosesargc] = strToChar("0");
++mosesargc;
mosesargv[mosesargc] = strToChar("-persistent-cache-size");
++mosesargc;
mosesargv[mosesargc] = strToChar("0");
++mosesargc;
Parameter* params = new Parameter();
if (!params->LoadParam(mosesargc,mosesargv)) {
params->Explain();
@ -353,13 +478,17 @@ int main(int argc, char** argv)
exit(1);
}
cerr << "start weights: " << StaticData::Instance().GetAllWeights() << endl;
xmlrpc_c::registry myRegistry;
xmlrpc_c::methodPtr const translator(new Translator);
xmlrpc_c::methodPtr const updater(new Updater);
xmlrpc_c::methodPtr const weightUpdater(new WeightUpdater);
myRegistry.addMethod("translate", translator);
myRegistry.addMethod("updater", updater);
myRegistry.addMethod("updateWeights", weightUpdater);
xmlrpc_c::serverAbyss myAbyssServer(
myRegistry,

View File

@ -0,0 +1,71 @@
#!/usr/bin/perl -w

# Example client for mosesserver.
#
# Reads source sentences from STDIN, one per line. For every sentence it
#   1. calls the "updateWeights" RPC with a fixed set of core and sparse weights,
#   2. calls the "translate" RPC, asking for the search graph ("sg"),
#   3. appends the translation to translations.out and the search graph
#      (extended format, one node per line) to seachGraph.out.

use strict;
use Encode;
use Frontier::Client;

#my $proxy = XMLRPC::Lite->proxy($url);
my $port = "8080";
my $url = "http://localhost:".$port."/RPC2";
my $server = Frontier::Client->new('url' => $url, 'encoding' => 'UTF-8');

my $verbose=0;

# Output files; abort right away if either cannot be created instead of
# silently printing to an unopened handle later on.
my $translations="translations.out";
open(TR, '>', $translations) or die "Cannot open $translations for writing: $!\n";
my $sg_out="seachGraph.out";
open(SG, '>', $sg_out) or die "Cannot open $sg_out for writing: $!\n";

# Prints a timestamped progress message; does nothing (and does not spawn
# `date`) unless $verbose is set.
sub log_step {
    my ($sentence, $what) = @_;
    return unless $verbose;
    my $date = `date`;
    chomp($date);
    print "[$date] sentence $sentence: $what\n";
}

#for (my $i=0;$i<scalar(@SENTENCE);$i++)
my $i=0;
while (my $text = <STDIN>)
{
    log_step($i, "translate");

    # update weights
    #my $core_weights = "0.031,-0.138,1.000,0.087,0.035,0.105,0.061,0.052,0.114,0.095,0.064,0.039,0.056,0.043,0.081";
    #my $core_weights = "0.031,-0.138,1.000,0.095,0.064,0.039";
    my $core_weights = "0.001,-0.001,1.000,1,1,1";
    my $sparse_weights = "pp_europea~European=0.015 pp_es~is=0.03 pp_es,=~is,equal=0.0001";
    my %param = ("core-weights" => $core_weights, "sparse-weights" => $sparse_weights);
    $server->call("updateWeights",(\%param));

    # translate
    #my %param = ("text" => $server->string($SENTENCE[$i]) , "sg" => "true");
    %param = ("text" => $text, "sg" => "true");
    my $result = $server->call("translate",(\%param));

    log_step($i, "process translation");

    # process translation
    my $mt_output = Encode::encode('UTF-8',$result->{'text'}); # no idea why that is necessary
    $mt_output =~ s/\|\S+//g; # no multiple factors, only first
    print "sentence $i >> $translations \n";
    print TR $mt_output."\n";

    # print out search graph
    print "sentence $i >> $sg_out \n";
    my $sg_ref = $result->{'sg'};
    foreach my $sgn (@$sg_ref) {
        # print out in extended format
        if ($sgn->{hyp} eq 0) {
            # initial hypothesis: only forward information is available
            print SG "$i hyp=$sgn->{hyp} stack=$sgn->{stack} forward=$sgn->{forward} fscore=$sgn->{fscore} \n";
        }
        else {
            print SG "$i hyp=$sgn->{hyp} stack=$sgn->{stack} back=$sgn->{back} score=$sgn->{score} transition=$sgn->{transition} ";
            if ($sgn->{"recombined"}) {
                print SG "recombined=$sgn->{recombined} ";
            }
            print SG "forward=$sgn->{forward} fscore=$sgn->{fscore} covered=$sgn->{'cover-start'}-$sgn->{'cover-end'} ";
            print SG "scores=\"$sgn->{scores}\" out=\"$sgn->{out}\" \n";
        }
    }

    ++$i;
}

# Closing explicitly surfaces write errors (e.g. a full disk) that would
# otherwise be lost when the handles are torn down at exit.
close(TR) or die "Error closing $translations: $!\n";
close(SG) or die "Error closing $sg_out: $!\n";

View File

@ -24,21 +24,6 @@ int BleuScoreState::Compare(const FFState& o) const
const BleuScoreState& other = dynamic_cast<const BleuScoreState&>(o);
if (m_source_length < other.m_source_length)
return -1;
if (m_source_length > other.m_source_length)
return 1;
if (m_target_length < other.m_target_length)
return -1;
if (m_target_length > other.m_target_length)
return 1;
if (m_scaled_ref_length < other.m_scaled_ref_length)
return -1;
if (m_scaled_ref_length > other.m_scaled_ref_length)
return 1;
int c = m_words.Compare(other.m_words);
if (c != 0)

View File

@ -176,6 +176,7 @@ namespace Moses {
linestream >> namestring;
linestream >> value;
FName fname(namestring);
cerr << "Setting sparse weight " << fname << " to value " << value << "." << endl;
set(fname,value);
}
return true;

View File

@ -248,6 +248,10 @@ int Hypothesis::RecombineCompare(const Hypothesis &compare) const
if (comp != 0) return comp;
}
cerr << "Recombining hypotheses.. " << endl;
cerr << "1: " << *this << endl;
cerr << "2: " << compare << endl << endl;
return 0;
}

View File

@ -642,7 +642,7 @@ void OutputSearchNode(long translationId, std::ostream &outputSearchGraphStream,
// special case: initial hypothesis
if ( searchNode.hypo->GetId() == 0 ) {
outputSearchGraphStream << " hyp=0 stack=0";
if (!extendedFormat) {
if (extendedFormat) {
outputSearchGraphStream << " forward=" << searchNode.forward << " fscore=" << searchNode.fscore;
}
outputSearchGraphStream << endl;
@ -671,9 +671,9 @@ void OutputSearchNode(long translationId, std::ostream &outputSearchGraphStream,
}
// output in extended format
if (searchNode.recombinationHypo != NULL)
outputSearchGraphStream << " hyp=" << searchNode.recombinationHypo->GetId();
else
// if (searchNode.recombinationHypo != NULL)
// outputSearchGraphStream << " hyp=" << searchNode.recombinationHypo->GetId();
// else
outputSearchGraphStream << " hyp=" << searchNode.hypo->GetId();
outputSearchGraphStream << " stack=" << searchNode.hypo->GetWordsBitmap().GetNumWordsCovered()