merged extract.cpp

This commit is contained in:
Marine Carpuat 2012-07-12 17:38:52 -04:00
commit 390cfce5aa
13 changed files with 170 additions and 549 deletions

View File

@ -4,8 +4,4 @@ obj TaggedCorpus.o : TaggedCorpus.cpp header_paths ;
obj PsdPhraseUtils.o : PsdPhraseUtils.cpp header_paths ;
exe extract-psd : psd_extract_features.cpp ..//..//psd//psd ..//tables-core ..//InputFileStream ..//SafeGetline.h PsdPhraseUtils.o header_paths ;
exe make-psd-table : psd_make_phrasetable.cpp ..//tables-core ..//InputFileStream ..//SafeGetline.h PsdPhraseUtils.o header_paths ;
exe tag-test-psd : psd_extract_test.cpp ..//tables-core ..//InputFileStream ..//SafeGetline.h PsdPhraseUtils.o header_paths ;
exe extract-psd : psd_extract_features.cpp TaggedCorpus.o ..//..//psd//psd ..//tables-core ..//InputFileStream ..//SafeGetline.h PsdPhraseUtils.o header_paths ;

View File

@ -42,11 +42,11 @@ PHRASE_ID getPhraseID(const string phrase, Vocabulary &wordVocab, PhraseVocab &v
// Reconstruct the space-separated surface string of target phrase `labelid`
// by looking each word id up in `tgtVocab`.
// Returns "" when the phrase is empty / the id is unknown to the phrase table.
string getPhrase(PHRASE_ID labelid, Vocabulary &tgtVocab, PhraseVocab &tgtPhraseVoc){
  PHRASE p = tgtPhraseVoc.getPhrase(labelid);
  string phrase = "";
  // NOTE(review): the bad merge had left two sequential copies of this loop,
  // which appended every word twice; a single pass is kept here. `size_t`
  // avoids the signed/unsigned comparison against p.size().
  for(size_t i = 0; i < p.size(); ++i){
    if (phrase != ""){
      phrase += " ";
    }
    phrase += tgtVocab.getWord(p[i]);
  }
  return phrase;
}
@ -60,88 +60,49 @@ bool readPhraseVocab(const char* vocabFile, Vocabulary &wordVocab, PhraseVocab &
SAFE_GETLINE(file, line, LINE_MAX_LENGTH, '\n', __FILE__);
if (file.eof()) return true;
PHRASE phrase = makePhraseAndVoc(string(line),wordVocab);
int pid = vocab.storeIfNew(phrase);
vocab.storeIfNew(phrase);
}
return true;
}
bool readPhraseTranslations(const char *ptFile, Vocabulary &srcWordVocab, Vocabulary &tgtWordVocab, PhraseVocab &srcPhraseVocab, PhraseVocab &tgtPhraseVocab, PhraseTranslations &transTable){
InputFileStream file(ptFile);
if(!file) return false;
while(!file.eof()){
char line[LINE_MAX_LENGTH];
SAFE_GETLINE(file, line, LINE_MAX_LENGTH, '\n', __FILE__);
if (file.eof()) return true;
vector<string> fields = TokenizeMultiCharSeparator(string(line), " ||| ");
// cerr << "TOKENIZED: " << fields.size() << " tokens in " << line << endl;
if (fields.size() < 2){
cerr << "Skipping malformed phrase-table entry: " << line << endl;
}
PHRASE_ID src = getPhraseID(fields[0],srcWordVocab,srcPhraseVocab);
PHRASE_ID tgt = getPhraseID(fields[1],tgtWordVocab,tgtPhraseVocab);
if (src && tgt){
PhraseTranslations::iterator itr = transTable.find(src);
if (itr == transTable.end() ){
map<PHRASE_ID,int> tgts;
tgts.insert(make_pair (tgt,0));
transTable.insert(make_pair (src,tgts));
}else{
map<PHRASE_ID,int>::iterator itr2 = itr->second.find(tgt);
if (itr2 == itr->second.end()){
itr->second.insert(make_pair(tgt,itr->second.size()));
}
}
}
/* }else{
cerr << "Skipping phrase-table entry due to OOV phrase: " << line << endl;
}
*/
InputFileStream file(ptFile);
if(!file) return false;
while(!file.eof()){
char line[LINE_MAX_LENGTH];
SAFE_GETLINE(file, line, LINE_MAX_LENGTH, '\n', __FILE__);
if (file.eof()) return true;
vector<string> fields = TokenizeMultiCharSeparator(string(line), " ||| ");
// cerr << "TOKENIZED: " << fields.size() << " tokens in " << line << endl;
if (fields.size() < 2){
cerr << "Skipping malformed phrase-table entry: " << line << endl;
}
return true;
PHRASE_ID src = getPhraseID(fields[0],srcWordVocab,srcPhraseVocab);
PHRASE_ID tgt = getPhraseID(fields[1],tgtWordVocab,tgtPhraseVocab);
if (src && tgt){
transTable.insert(make_pair(src, tgt));
}
}
/* }else{
cerr << "Skipping phrase-table entry due to OOV phrase: " << line << endl;
}
*/
return true;
}
// Return true iff (src, tgt) is a known phrase pair in the translation table.
// PhraseTranslations is a multimap<PHRASE_ID, PHRASE_ID> (see header), so all
// candidates for `src` occupy the contiguous range [lower_bound, upper_bound).
bool exists(PHRASE_ID src, PHRASE_ID tgt, PhraseTranslations &transTable){
  // NOTE(review): the merge had left the pre-multimap (map-of-maps)
  // implementation in front of this one, making this scan unreachable; only
  // the multimap-based lookup matches the current PhraseTranslations typedef.
  PhraseTranslations::const_iterator it;
  for (it = transTable.lower_bound(src); it != transTable.upper_bound(src); ++it) {
    if (it->second == tgt)
      return true;
  }
  return false;
}
// Return true iff the translation table holds at least one candidate for `src`.
bool exists(PHRASE_ID src, PhraseTranslations &transTable){
  return transTable.count(src) > 0;
}
// Write all translation candidates of `src` to `fileName`, one per line,
// formatted as "<per-phrase label index> <target phrase id>".
// Returns false — without creating the file — when `src` has no entry in the
// table or the output file cannot be opened.
bool printTransToFile(string fileName, PhraseTranslations &transTable, PHRASE_ID src){
  PhraseTranslations::iterator itr = transTable.find(src);
  // Look up first: the previous code opened (and truncated) the output file
  // before this check, leaving a stray empty file for unknown phrases.
  if (itr == transTable.end()) return false;
  ofstream out(fileName.c_str());
  if (!out) return false; // open failure was previously silently ignored
  for(map<PHRASE_ID,int>::iterator itr2 = (itr->second).begin(); itr2 != (itr->second).end(); itr2++){
    out << itr2->second << " " << itr2->first << endl;
  }
  return true;
}
// Write the translation candidates of `src` to `fileName`, one per line,
// formatted as "<per-phrase label index> <target phrase string>", resolving
// each target phrase id to words via `pVocab`/`wVocab`.
// Returns false when `src` is unknown, the file cannot be opened, or a
// candidate resolves to an empty phrase (in the last case lines written for
// earlier candidates are left in the file, matching the previous behavior).
bool printTransStringToFile(string fileName, PhraseTranslations &transTable, PHRASE_ID src, PhraseVocab &pVocab, Vocabulary &wVocab){
  PhraseTranslations::iterator itr = transTable.find(src);
  // Look up before opening so an unknown phrase does not truncate/create the file.
  if (itr == transTable.end()) return false;
  ofstream out(fileName.c_str());
  if (!out) return false; // open failure was previously silently ignored
  for(map<PHRASE_ID,int>::iterator itr2 = (itr->second).begin(); itr2 != (itr->second).end(); itr2++){
    PHRASE p = pVocab.getPhrase(itr2->first);
    if (p.empty()) return false;
    string phrase = wVocab.getWord(p[0]);
    // size_t fixes the old signed/unsigned comparison (`int i < p.size()`).
    for(size_t i = 1; i < p.size(); i++){
      phrase = phrase + " " + wVocab.getWord(p[i]);
    }
    out << itr2->second << " " << phrase << endl;
  }
  return true;
}
bool readPhraseTranslations(const char *ptFile, Vocabulary &srcWordVocab, Vocabulary &tgtWordVocab, PhraseVocab &srcPhraseVocab, PhraseVocab &tgtPhraseVocab, PhraseTranslations &transTable, map<string,string> &transTableScores){
InputFileStream file(ptFile);
if(!file) return false;
@ -156,19 +117,9 @@ bool readPhraseTranslations(const char *ptFile, Vocabulary &srcWordVocab, Vocabu
PHRASE_ID src = getPhraseID(fields[0],srcWordVocab,srcPhraseVocab);
PHRASE_ID tgt = getPhraseID(fields[1],tgtWordVocab,tgtPhraseVocab);
if (src && tgt){
PhraseTranslations::iterator itr = transTable.find(src);
transTable.insert(make_pair(src, tgt));
string stpair = SPrint(src)+" "+SPrint(tgt);
transTableScores.insert(make_pair (stpair,fields[2]));
if (itr == transTable.end() ){
map<PHRASE_ID,int> tgts;
tgts.insert(make_pair (tgt,0));
transTable.insert(make_pair (src,tgts));
}else{
map<PHRASE_ID,int>::iterator itr2 = itr->second.find(tgt);
if (itr2 == itr->second.end()){
itr->second.insert(make_pair(tgt,itr->second.size()));
}
}
}
/* }else{
cerr << "Skipping phrase-table entry due to OOV phrase: " << line << endl;

View File

@ -10,7 +10,7 @@
using namespace std;
typedef MosesTraining::PhraseTable PhraseVocab;
typedef map< MosesTraining::PHRASE_ID, map< MosesTraining::PHRASE_ID, int > > PhraseTranslations;
typedef multimap< MosesTraining::PHRASE_ID, MosesTraining::PHRASE_ID > PhraseTranslations;
bool readPhraseVocab(const char* vocabFile, MosesTraining::Vocabulary &wordVocab, PhraseVocab &vocab);
@ -29,9 +29,4 @@ bool exists(MosesTraining::PHRASE_ID src, MosesTraining::PHRASE_ID tgt, PhraseTr
bool exists(MosesTraining::PHRASE_ID src, PhraseTranslations &transTable);
bool printTransToFile(string fileName, PhraseTranslations &transTable, MosesTraining::PHRASE_ID src);
bool printTransStringToFile(string fileName, PhraseTranslations &transTable, MosesTraining::PHRASE_ID src, PhraseVocab &pVocab, MosesTraining::Vocabulary &wVocab);
#endif

View File

@ -1,7 +1,7 @@
#ifndef _PSD_H_
#define _PSD_H_
enum CLASSIFIER_TYPE {MEGAM,VW,RAW};
enum PSD_MODEL_TYPE {PHRASAL,GLOBAL};
enum CLASSIFIER_TYPE {MEGAM, VWFile, VWLib};
enum PSD_MODEL_TYPE {PHRASAL, GLOBAL};
#endif

View File

@ -16,22 +16,23 @@
#include "PsdPhraseUtils.h"
#include "FeatureExtractor.h"
#include "FeatureConsumer.h"
#include "TaggedCorpus.h"
using namespace std;
using namespace Moses;
using namespace MosesTraining;
using namespace boost::bimaps;
using namespace PSD;
#define LINE_MAX_LENGTH 10000
// globals
CLASSIFIER_TYPE psd_classifier = RAW;
CLASSIFIER_TYPE psd_classifier = VWFile;
PSD_MODEL_TYPE psd_model = GLOBAL;
string ptDelim = " ||| ";
string factorDelim = "|";
int subdirsize=1000;
string ext = ".contexts";
string labelext = ".labels";
bool psd_train_mode = true;
Vocabulary srcVocab;
Vocabulary tgtVocab;
@ -42,7 +43,7 @@ int main(int argc,char* argv[]){
cerr << "syntax: extract-psd context.template corpus.psd corpus.raw corpus.factored phrase-table sourcePhraseVocab targetPhraseVocab outputdir/filename [options]\n";
cerr << endl;
cerr << "Options:" << endl;
cerr << "\t --ClassifierType vw|megam|none" << endl;
cerr << "\t --ClassifierType vw|megam" << endl;
cerr << "\t --PsdType phrasal|global" << endl;
exit(1);
}
@ -58,12 +59,8 @@ int main(int argc,char* argv[]){
if (strcmp(argv[i],"--ClassifierType") == 0){
char* format = argv[++i];
if (strcmp(format,"vw") == 0){
psd_classifier = VW;
psd_classifier = VWFile;
ext = ".vw-data";
labelext = ".vw-header";
}else if (strcmp(format,"none") == 0){
psd_classifier = RAW;
ext = ".context";
}else if (strcmp(format,"megam") == 0){
psd_classifier = MEGAM;
ext = ".megam";
@ -119,54 +116,45 @@ int main(int argc,char* argv[]){
int i = 0;
int csid = 0;
// prep output files for PHRASAL setting
map<PHRASE_ID, ostream*> outFiles;
// get label info and prep output data file for GLOBAL setting
ofstream* globalDataOut = new ofstream();
//OutputFileStream* globalDataOut;
if (psd_model == GLOBAL){
//globalDataOut = new OutputFileStream(output+ext);
globalDataOut->open((output+ext).c_str());
string labelName = output+labelext;
if (psd_train_mode && ! printVwHeaderFile(labelName,transTable,tgtPhraseVoc,tgtVocab)){
cerr << "Failed to write file " << labelName << endl;
}
// create target phrase index for feature extractor
TargetIndexType extractorTargetIndex;
for (size_t i = 0; i < tgtPhraseVoc.phraseTable.size(); i++) {
extractorTargetIndex.insert(TargetIndexType::value_type(getPhrase(i, tgtVocab, tgtPhraseVoc), i));
}
FeatureExtractor extractor(extractorTargetIndex, true);
// prep feature consumers for PHRASAL setting
map<PHRASE_ID, FeatureConsumer*> consumers;
// feature consumer for GLOBAL setting
FeatureConsumer *globalOut = NULL;
if (psd_classifier == VWLib)
globalOut = new VWLibraryTrainConsumer(output);
else
globalOut = new VWFileTrainConsumer(output);
cerr<< "Phrase tables read. Now reading in corpus." << endl;
while(true) {
if (psd.eof()) break;
if (++i % 100000 == 0) cerr << "." << flush;
char psdLine[LINE_MAX_LENGTH];
// get phrase pair
SAFE_GETLINE((psd),psdLine, LINE_MAX_LENGTH, '\n', __FILE__);
if (psd.eof()) break;
vector<string> token = Tokenize(psdLine,"\t");
//ycerr << "TOKENIZED: " << token.size() << " tokens in " << psdLine << endl;
// in test mode if no labels are provided
if (token.size() == 3){
psd_train_mode = false;
}else if (token.size() == 7 ){
psd_train_mode = true;
}else{
cerr << "Malformed psd entry: " << psdLine << endl;
exit(1);
}
int sid = Scan<int>(token[0].c_str());
int src_start = Scan<int>(token[1].c_str());
int src_end = Scan<int>(token[2].c_str());
int tgt_start = -1;
int tgt_end = -1;
if (psd_train_mode == true){
tgt_start = Scan<int>(token[3].c_str());
tgt_end = Scan<int>(token[4].c_str());
}
size_t sid = Scan<size_t>(token[0].c_str());
size_t src_start = Scan<size_t>(token[1].c_str());
size_t src_end = Scan<size_t>(token[2].c_str());
size_t tgt_start = Scan<size_t>(token[3].c_str());
size_t tgt_end = Scan<size_t>(token[4].c_str());
char rawSrcLine[LINE_MAX_LENGTH];
char tagSrcLine[LINE_MAX_LENGTH];
// go to current sentence
while(csid < sid){
if (src.eof()) break;
SAFE_GETLINE((src),rawSrcLine, LINE_MAX_LENGTH, '\n', __FILE__);
@ -174,6 +162,7 @@ int main(int argc,char* argv[]){
SAFE_GETLINE((srcTag),tagSrcLine, LINE_MAX_LENGTH, '\n', __FILE__);
++csid;
}
assert(csid == sid);
vector<string> sent = Tokenize(rawSrcLine);
string phrase;
@ -187,89 +176,40 @@ int main(int argc,char* argv[]){
}
}
PHRASE_ID srcid = getPhraseID(phrase,srcVocab,psdPhraseVoc);
PHRASE_ID srcid = getPhraseID(phrase, srcVocab, psdPhraseVoc);
cout << "PHRASE : " << srcid << " " << phrase << endl;
string escapedTagSrcLine = string(tagSrcLine);
if (psd_classifier == VW){
escapedTagSrcLine = escapeVwSpecialChars(string(tagSrcLine));
factorDelim = "_PIPE_";
}
ContextType factoredSrcLine = parseTaggedString(tagSrcLine, factorDelim);
PHRASE_ID labelid = -1;
if (psd_train_mode == true){
string tgtphrase = token[6];
labelid = getPhraseID(tgtphrase,tgtVocab,tgtPhraseVoc);
string tgtphrase = token[6];
PHRASE_ID labelid = getPhraseID(tgtphrase,tgtVocab,tgtPhraseVoc);
vector<float> losses;
vector<size_t> translations;
PhraseTranslations::const_iterator transIt;
for (transIt = transTable.lower_bound(srcid); transIt != transTable.upper_bound(srcid); transIt++) {
translations.push_back(transIt->second);
losses.push_back(labelid == transIt->second ? 0 : 1);
}
if (srcid != 0){
if ( psd_train_mode && exists(srcid,labelid,transTable) || !psd_train_mode && exists(srcid,transTable)){
FeatureConsumer *fc = NULL;
if (psd_classifier == RAW)
fc = new VWLibraryTrainConsumer();
else
fc = new VWFileTrainConsumer();
vector<string> features = fs.extract(src_start,src_end,escapedTagSrcLine,factorDelim);
if (exists(srcid, labelid, transTable)) {
if (psd_model == PHRASAL){
map<PHRASE_ID, ostream*>::iterator i = outFiles.find(srcid);
if (i == outFiles.end()){
map<PHRASE_ID, FeatureConsumer*>::iterator i = consumers.find(srcid);
if (i == consumers.end()){
int low = floor(srcid/subdirsize)*subdirsize;
int high = low+subdirsize-1;
string output_subdir = output +"/"+SPrint(low)+"-"+SPrint(high);
struct stat stat_buf;
mkdir(output_subdir.c_str(),S_IREAD | S_IWRITE | S_IEXEC);
string fileName = output_subdir + "/" + SPrint(srcid) + ext;
ofstream* testFile = new ofstream();
testFile->open(fileName.c_str());
outFiles.insert(pair<PHRASE_ID,ostream*>(srcid,testFile));
// also print out label dictionary
string labelName = output_subdir + "/" + SPrint(srcid) + labelext;
switch(psd_classifier) {
case VW:
if ( psd_train_mode && ! printVwHeaderFile(labelName,transTable,srcid,tgtPhraseVoc,tgtVocab)){
cerr << "ERROR: could not write file " << labelName << endl;
}
break;
case RAW:
break;
default:
if ( psd_train_mode && ! printTransStringToFile(labelName,transTable,srcid,tgtPhraseVoc,tgtVocab)){
cerr << "ERROR: could not write file " << labelName << endl;
}
}
FeatureConsumer *fc = NULL;
if (psd_classifier == VWLib)
fc = new VWLibraryTrainConsumer(fileName);
else
fc = new VWFileTrainConsumer(fileName);
consumers.insert(pair<PHRASE_ID,FeatureConsumer*>(srcid, fc));
}
int perPhraseLabelid = transTable[srcid][labelid];
string sf = "";
if (psd_train_mode){
sf = makeVwTrainingInstance(features,perPhraseLabelid);
}else{
sf = makeVwTestingInstance(features);
}
switch (psd_classifier) {
case VW:
//(*outFiles[srcid]) << perPhraseLabelid << " | " << sf << endl;
(*outFiles[srcid]) << sf << endl;
break;
default:
(*outFiles[srcid]) << src_start << "\t" << src_end << "\t" << perPhraseLabelid << "\t" << tagSrcLine << endl;
break;
}
}
if (psd_model == GLOBAL && psd_classifier == VW ){
set<PHRASE_ID> labels;
labels.insert(labelid); //TODO: collect all translation candidates with fractional counts!
string sfeatures = "";
if (psd_train_mode){
sfeatures = makeVwGlobalTrainingInstance(srcid,features,labels,transTable,tgtPhraseVoc,tgtVocab);
(*globalDataOut) << sfeatures << endl;
}else{
sfeatures = makeVwGlobalTestingInstance(srcid,features,transTable,tgtPhraseVoc,tgtVocab);
cout << "FEATURES " << sfeatures << endl;
(*globalDataOut) << sfeatures << endl;
}
}
if (psd_model == GLOBAL && psd_classifier == RAW ){
(*globalDataOut) << psdLine << "\t" << srcid << "\t" << labelid << "\t" << tagSrcLine << endl;
extractor.GenerateFeatures(consumers[srcid], factoredSrcLine, src_start, src_end, translations, losses);
} else { // GLOBAL model
extractor.GenerateFeatures(globalOut, factoredSrcLine, src_start, src_end, translations, losses);
}
}
}

View File

@ -1,97 +0,0 @@
#include <iostream>
#include <vector>
#include <string>
#include <cstring>
#include <sys/stat.h>
#include <sys/types.h>
#include "SafeGetline.h"
#include "InputFileStream.h"
#include "OutputFileStream.h"
#include "tables-core.h"
#include "Util.h"
#include "psd.h"
#include "PsdPhraseUtils.h"
using namespace std;
using namespace Moses;
#define LINE_MAX_LENGTH 10000
string factorDelim = "|";
MosesTraining::Vocabulary srcVocab;
MosesTraining::Vocabulary tgtVocab;
// Entry point of tag-test-psd: scans a raw test corpus and emits, for every
// source span whose phrase is in the PSD phrase vocabulary AND has at least
// one entry in the phrase-translation table, a tab-separated line
// "<sentence id>\t<span start>\t<span end>" to the output file.
// Sentence ids are 1-based (sid is incremented before the line is processed);
// span offsets are 0-based token positions within the sentence.
int main(int argc,char* argv[]){
if (argc < 7){
cerr << "Tag source phrases that are candidates for PSD disambiguation in test set\n\n";
cerr << "syntax: tag-psd-test corpus.raw base-phrase-table source-phrase-vocab tgt-phrase-vocab maxPhraseLength output-file-name [options]\n";
cerr << endl;
cerr << "Options:" << endl;
exit(1);
}
char* &fileNameSrcRaw = argv[1]; // raw source context
char* &fileNamePT = argv[2]; // phrase table
char* &fileNameSrcVoc = argv[3]; // source phrase vocabulary
char* &fileNameTgtVoc = argv[4];
int maxPhraseLength = Scan<int>(argv[5]);
char* &fileNameOut = argv[6]; // output file
// store word and phrase vocab and phrase table
PhraseVocab psdPhraseVoc;
if (!readPhraseVocab(fileNameSrcVoc,srcVocab,psdPhraseVoc)){
cerr << "Error reading in source phrase vocab" << endl;
exit(1);
}
PhraseVocab tgtPhraseVoc;
if (!readPhraseVocab(fileNameTgtVoc,tgtVocab,tgtPhraseVoc)){
cerr << "Error reading in target phrase vocab" << endl;
exit(1);
}
PhraseTranslations transTable;
if (!readPhraseTranslations(fileNamePT, srcVocab, tgtVocab, psdPhraseVoc, tgtPhraseVoc, transTable)){
cerr << "Error reading in phrase translation table " << endl;
exit(1);
}
// we will print to psdOut
ofstream psdOut;
psdOut.open(fileNameOut);
if ( !psdOut ){
cerr << "Error opening " << fileNameOut << endl;
exit(1);
}
// read in corpus
InputFileStream src(fileNameSrcRaw);
if (src.fail()){
cerr << "ERROR: could not open " << fileNameSrcRaw << endl;
exit(1);
}
int sid = 0;
// Stream the corpus one line (sentence) at a time.
while(true){
if (src.eof()) break;
char srcLine[LINE_MAX_LENGTH];
SAFE_GETLINE((src),srcLine,LINE_MAX_LENGTH, '\n', __FILE__);
sid++;
if (src.eof()) break;
vector<string> sentence = Tokenize(srcLine);
// Enumerate every span [s, e] of up to maxPhraseLength tokens.
// NOTE(review): `int s`/`int e` are compared against the unsigned
// sentence.size() — harmless for realistic sentence lengths, but a
// signed/unsigned mismatch nonetheless.
for(int s = 0; s < sentence.size(); s++){
string phrase = "";
for(int e = s; e < s+maxPhraseLength && e < sentence.size(); e++){
if (phrase != "") phrase += " ";
phrase += sentence[e];
MosesTraining::PHRASE_ID pid = getPhraseID(phrase, srcVocab, psdPhraseVoc);
// Growing the span can only be useful if the current prefix phrase is
// itself known; 0 is the "unknown phrase" id, so stop extending here.
if (pid == 0) break;
// Only spans with at least one candidate translation are tagged.
if (!exists(pid, transTable)) break;
psdOut << sid << "\t" << s << "\t" << e << endl;
}
}
}
}

View File

@ -1,196 +0,0 @@
#include <iostream>
#include <vector>
#include <string>
#include <cstring>
#include <sys/stat.h>
#include <sys/types.h>
#include "SafeGetline.h"
#include "InputFileStream.h"
#include "OutputFileStream.h"
#include "tables-core.h"
#include "Util.h"
#include "psd.h"
#include "PsdPhraseUtils.h"
using namespace std;
using namespace Moses;
#define LINE_MAX_LENGTH 10000
// globals
CLASSIFIER_TYPE psd_classifier = RAW;
PSD_MODEL_TYPE psd_model = GLOBAL;
string ptDelim = " ||| ";
string factorDelim = "|";
int subdirsize=1000;
string labelext = ".labels";
string predext = ".vw.pred";
MosesTraining::Vocabulary srcVocab;
MosesTraining::Vocabulary tgtVocab;
string output_dir="psd";
// Entry point of make-psd-table: walks the tagged PSD examples (corpus.psd)
// in corpus order alongside the raw source corpus and produces
//   <output>.corpus        — the test corpus with every token replaced by a
//                            globally unique integer position, and
//   <output>.phrase-table  — a phrase table keyed by those integerized source
//                            spans, copying the baseline scores verbatim.
// Context-dependent PSD scores are NOT added yet (only the RAW classifier
// path is implemented; anything else aborts).
int main(int argc,char* argv[]){
cerr << "constructor for phrase-table augmented with context-dependent PSD scores.\n\n";
if (argc < 8){
cerr << "syntax: make-psd-table path-to-psd-predictions corpus.psd corpus.raw base-phrase-table sourcePhraseVocab targetPhraseVocab [options]\n";
cerr << endl;
cerr << "Options:" << endl;
cerr << "\t --ClassifierType vw|megam|none" << endl;
cerr << "\t --PsdType phrasal|global" << endl;
exit(1);
}
char* &fileNamePred = argv[1]; // path to PSD predictions
char* &fileNamePsd = argv[2]; // location in corpus of PSD phrases (optionallly annotated with the position and phrase type of their translations.)
char* &fileNameSrcRaw = argv[3]; // raw source context
char* &fileNamePT = argv[4]; // phrase table
char* &fileNameSrcVoc = argv[5]; // source phrase vocabulary
char* &fileNameTgtVoc = argv[6]; // target phrase vocabulary
char* &fileNameOutput = argv[7]; //root name for the integerized output corpus and associated phrase-table
// Parse trailing options; each flag consumes the following argv slot.
for(int i = 8; i < argc; i++){
if (strcmp(argv[i],"--ClassifierType") == 0){
char* format = argv[++i];
if (strcmp(format,"vw") == 0){
psd_classifier = VW;
predext = ".vw.pred";
}else if (strcmp(format,"none") == 0){
psd_classifier = RAW;
}else if (strcmp(format,"megam") == 0){
psd_classifier = MEGAM;
predext = ".megam";
cerr << "megam format isn't supported" << endl;
exit(1);
}else{
cerr << "classifier " << format << "isn't supported" << endl;
exit(1);
}
}
if (strcmp(argv[i],"--PsdType") == 0){
char* format = argv[++i];
if (strcmp(format,"global") == 0){
psd_model = GLOBAL;
}else if (strcmp(format,"phrasal") == 0){
psd_model = PHRASAL;
}else{
cerr << "PSD model type " << format << "isn't supported" << endl;
exit(1);
}
}
}
// Output streams for the integerized corpus and phrase table.
string corpus = string(fileNameOutput) + ".corpus";
ofstream outCorpus(corpus.c_str());
string pt = string(fileNameOutput) + ".phrase-table";
ofstream outPT(pt.c_str());
// NOTE(review): open/read failures below only warn and fall through —
// presumably intentional best-effort behavior, but worth confirming.
InputFileStream src(fileNameSrcRaw);
if (src.fail()){
cerr << "ERROR: could not open " << fileNameSrcRaw << endl;
}
InputFileStream psd(fileNamePsd);
if (psd.fail()){
cerr << "ERROR: could not open " << fileNamePsd << endl;
}
// store word and phrase vocab
PhraseVocab psdPhraseVoc;
if (!readPhraseVocab(fileNameSrcVoc,srcVocab,psdPhraseVoc)){
cerr << "Error reading in source phrase vocab" << endl;
}
PhraseVocab tgtPhraseVoc;
if (!readPhraseVocab(fileNameTgtVoc,tgtVocab,tgtPhraseVoc)){
cerr << "Error reading in target phrase vocab" << endl;
}
// store baseline phrase-table
PhraseTranslations transTable;
map<string,string> transTableScores;
if (!readPhraseTranslations(fileNamePT, srcVocab, tgtVocab, psdPhraseVoc, tgtPhraseVoc, transTable, transTableScores)){
// if (!readPhraseTranslations(fileNamePT, srcVocab, tgtVocab, psdPhraseVoc, tgtPhraseVoc, transTable)){
cerr << "Error reading in phrase translation table " << endl;
}
// get ready to read in VW predictions
// NOTE(review): vwPredFiles/vwPredFile/predFiles are declared but never used
// below — scaffolding for the not-yet-implemented prediction path.
map<MosesTraining::PHRASE_ID,InputFileStream*> vwPredFiles; //for phrasal model
InputFileStream* vwPredFile; //for global model
// go through tagged PSD examples in the order they occur in the test corpus
int i = 0;
int csid = 0; // id of the source sentence currently loaded in rawSrcLine
map<MosesTraining::PHRASE_ID, istream*> predFiles;
int toks_covered = -1; //last token position covered in test corpus
cerr<< "Phrase tables read. Now reading in corpus." << endl;
while(true) {
if (psd.eof()) break;
if (++i % 100000 == 0) cerr << "." << flush;
char psdLine[LINE_MAX_LENGTH];
SAFE_GETLINE((psd),psdLine, LINE_MAX_LENGTH, '\n', __FILE__);
if (psd.eof()) break;
// Each PSD line: <sentence id> <span start> <span end> [target span info...]
vector<string> token = Tokenize(psdLine);
assert(token.size() > 2);
int sid = Scan<int>(token[0].c_str());
int src_start = Scan<int>(token[1].c_str());
int src_end = Scan<int>(token[2].c_str());
// int tgt_start = Scan<int>(token[3].c_str());
// int tgt_end = Scan<int>(token[4].c_str());
// NOTE(review): rawSrcLine is only (re)filled by the catch-up loop below;
// consecutive PSD entries for the same sentence reuse the previous contents.
char rawSrcLine[LINE_MAX_LENGTH];
// Advance through the raw corpus until the PSD entry's sentence is loaded,
// emitting the integerized form of every sentence passed over.
while(csid < sid){
if (src.eof()) break;
SAFE_GETLINE((src),rawSrcLine, LINE_MAX_LENGTH, '\n', __FILE__);
vector<string> sent = Tokenize(rawSrcLine);
// print integerized test set sentence
string isent;
for(int j = 0; j < sent.size(); j++){
if (isent != "") isent+=" ";
isent += SPrint(toks_covered+j+1);
}
outCorpus << isent << endl;
toks_covered += sent.size();
++csid;
}
assert(csid == sid);
vector<string> sent = Tokenize(rawSrcLine);
// Rebuild the surface phrase covered by [src_start, src_end].
string phrase = sent[src_start];
assert(src_end < sent.size());
for(size_t j = src_start + 1; j < src_end + 1; j++){
phrase = phrase+ " " + sent[j];
}
MosesTraining::PHRASE_ID srcid = getPhraseID(phrase,srcVocab,psdPhraseVoc);
if (srcid != 0){
// make integerized source id
// NOTE(review): this inner `i` shadows the example counter declared above.
string src2int = SPrint(toks_covered-sent.size()+src_start+1);
for(int i = src_start+1; i < src_end + 1; ++i){
src2int = src2int + " " + SPrint(toks_covered-sent.size()+i+1);
}
// for now don't get PSD prediction, only integerize PT without adding context-dependent score
if (psd_classifier != RAW){
cerr << "classifier type not supported yet" << endl;
exit(1);
}
// find candidate translations
PhraseTranslations::iterator itr = transTable.find(srcid);
if ( itr != transTable.end()){
for(map<MosesTraining::PHRASE_ID,int>::iterator itr2 = itr->second.begin(); itr2 != itr->second.end(); itr2++){
MosesTraining::PHRASE_ID labelid = itr2->first;
string label = getPhrase(labelid, tgtVocab, tgtPhraseVoc);
// Baseline scores are keyed by the "<src id> <tgt id>" string pair.
string sl_key = SPrint(srcid) + " " + SPrint(labelid);
map<string,string>::iterator itr3 = transTableScores.find(sl_key);
if (itr3 != transTableScores.end()){
// print integerized phrase-table
outPT << src2int << " ||| " << label << " ||| " << itr3->second << endl;
}
}
}
}
}
outPT.close();
outCorpus.close();
}

View File

@ -7,11 +7,14 @@
#include <sstream>
#include <deque>
#ifdef HAVE_VW
namespace PSD
{
// #ifdef HAVE_VW
// forward declarations to avoid dependency on VW
struct vw;
class ezexample;
#endif
// #endif
// abstract consumer
class FeatureConsumer
@ -49,7 +52,7 @@ private:
std::string EscapeSpecialChars(const std::string &str);
};
#ifdef HAVE_VW
// #ifdef HAVE_VW
// abstract consumer that trains/predicts using VW library interface
class VWLibraryConsumer : public FeatureConsumer
{
@ -86,6 +89,8 @@ private:
virtual void Train(const std::string &label, float loss);
virtual float Predict(const std::string &label);
};
#endif // HAVE_VW
// #endif // HAVE_VW
} // namespace PSD
#endif // moses_FeatureConsumer_h

View File

@ -5,46 +5,49 @@ using namespace std;
using namespace boost::bimaps;
using namespace Moses;
namespace PSD
{
FeatureExtractor::FeatureExtractor(FeatureTypes ft,
FeatureConsumer *fc,
const TargetIndexType &targetIndex,
bool train)
: m_ft(ft), m_fc(fc), m_targetIndex(targetIndex), m_train(train)
: m_targetIndex(targetIndex), m_train(train)
{
}
void FeatureExtractor::GenerateFeatures(const ContextType &context,
void FeatureExtractor::GenerateFeatures(FeatureConsumer *fc,
const ContextType &context,
size_t spanStart,
size_t spanEnd,
const vector<size_t> &translations,
vector<float> &losses)
{
m_fc->SetNamespace('s', true);
if (m_ft.m_sourceExternal) {
GenerateContextFeatures(context, spanStart, spanEnd);
fc->SetNamespace('s', true);
if (PSD_SOURCE_EXTERNAL) {
GenerateContextFeatures(context, spanStart, spanEnd, fc);
}
if (m_ft.m_sourceInternal) {
if (PSD_SOURCE_INTERNAL) {
vector<string> sourceForms(spanEnd - spanStart + 1);
for (size_t i = spanStart; i <= spanEnd; i++) {
sourceForms[i] = context[i][0]; // XXX assumes that form is the 0th factor
}
GenerateInternalFeatures(sourceForms);
GenerateInternalFeatures(sourceForms, fc);
}
vector<size_t>::const_iterator transIt = translations.begin();
vector<float>::iterator lossIt = losses.begin();
for (; transIt != translations.end(); transIt++, lossIt++) {
assert(lossIt != losses.end());
m_fc->SetNamespace('t', false);
if (m_ft.m_targetInternal) {
GenerateInternalFeatures(Tokenize(" ", m_targetIndex.right.find(*transIt)->second));
fc->SetNamespace('t', false);
if (PSD_TARGET_INTERNAL) {
GenerateInternalFeatures(Tokenize(" ", m_targetIndex.right.find(*transIt)->second), fc);
}
if (m_train) {
m_fc->Train(SPrint(*transIt), *lossIt);
fc->Train(SPrint(*transIt), *lossIt);
} else {
*lossIt = m_fc->Predict(SPrint(*transIt));
*lossIt = fc->Predict(SPrint(*transIt));
}
}
}
@ -60,25 +63,26 @@ string FeatureExtractor::BuildContextFeature(size_t factor, int index, const str
void FeatureExtractor::GenerateContextFeatures(const ContextType &context,
size_t spanStart,
size_t spanEnd)
size_t spanEnd,
FeatureConsumer *fc)
{
vector<size_t>::const_iterator factorIt;
for (factorIt = m_ft.m_factors.begin(); factorIt != m_ft.m_factors.end(); factorIt++) {
for (size_t i = 1; i <= m_ft.m_contextWindow; i++) {
for (size_t fact = 0; fact <= PSD_FACTOR_COUNT; fact++) {
for (size_t i = 1; i <= PSD_CONTEXT_WINDOW; i++) {
if (spanStart >= i)
m_fc->AddFeature(BuildContextFeature(*factorIt, i, context[spanstart - i][*factorit]);
fc->AddFeature(BuildContextFeature(fact, i, context[spanStart - i][fact]));
if (spanEnd + i < context.size())
m_fc->AddFeature(BuildContextFeature(*factorIt, i, context[spanstart - i][*factorit]);
fc->AddFeature(BuildContextFeature(fact, i, context[spanStart - i][fact]));
}
}
}
void FeatureExtractor::GenerateInternalFeatures(const vector<string> &span)
void FeatureExtractor::GenerateInternalFeatures(const vector<string> &span, FeatureConsumer *fc)
{
m_fc->AddFeature("p^" + Join("_", span));
fc->AddFeature("p^" + Join("_", span));
vector<string>::const_iterator it;
for (it = span.begin(); it != span.end(); it++) {
m_fc->AddFeature("w^" + *it);
fc->AddFeature("w^" + *it);
}
}
} // namespace PSD

View File

@ -8,51 +8,48 @@
#include <map>
#include <boost/bimap/bimap.hpp>
namespace PSD
{
// vector of words, each word is a vector of factors
typedef std::vector<std::vector<std::string> > ContextType;
// index of possible target spans
typedef boost::bimaps::bimap<std::string, size_t> TargetIndexType;
// configuration of feature extractor
struct FeatureTypes
{
bool m_sourceExternal; // generate context features
bool m_sourceInternal; // generate source-side phrase-internal features
bool m_targetInternal; // generate target-side phrase-internal features
bool m_paired; // generate paired features
bool m_bagOfWords; // generate bag-of-words features
// configuration of feature extraction, shared, global
const bool PSD_SOURCE_EXTERNAL = true; // generate context features
const bool PSD_SOURCE_INTERNAL = true; // generate source-side phrase-internal features
const bool PSD_TARGET_INTERNAL = true; // generate target-side phrase-internal features
const bool PSD_PAIRED = false; // generate paired features
const bool PSD_BAG_OF_wORDS = false; // generate bag-of-words features
size_t m_contextWindow; // window size for context features
const size_t PSD_CONTEXT_WINDOW = 2; // window size for context features
// list of factors that should be extracted from context (e.g. 0,1,2)
std::vector<size_t> m_factors;
};
const size_t[] PSD_FACTORS = { 0, 1, 2 };
const size_t PSD_FACTOR_COUNT = 3;
// extract features
class FeatureExtractor
{
public:
FeatureExtractor(FeatureTypes ft,
FeatureConsumer *fc,
const TargetIndexType &targetIndex,
bool train);
FeatureExtractor(const TargetIndexType &targetIndex, bool train);
void GenerateFeatures(const ContextType &context,
void GenerateFeatures(FeatureConsumer *fc,
const ContextType &context,
size_t spanStart,
size_t spanEnd,
const std::vector<size_t> &translations,
std::vector<float> &losses);
private:
FeatureTypes m_ft;
FeatureConsumer *m_fc;
const TargetIndexType &m_targetIndex;
bool m_train;
void GenerateContextFeatures(const ContextType &context, size_t spanStart, size_t spanEnd);
void GenerateInternalFeatures(const std::vector<std::string> &span);
void GenerateContextFeatures(const ContextType &context, size_t spanStart, size_t spanEnd, FeatureConsumer *fc);
void GenerateInternalFeatures(const std::vector<std::string> &span, FeatureConsumer *fc);
std::string BuildContextFeature(size_t factor, int index, const std::string &value);
};
} // namespace PSD
#endif // moses_FeatureExtractor_h

17
psd/Jamfile Normal file
View File

@ -0,0 +1,17 @@
alias headers : : : : <include>. <include>../moses/src <include>.. ;
boost 103600 ;
# VW
local with-vw = [ option.get "with-vw" ] ;
if $(with-vw) {
lib vw : : <search>$(with-vw)/lib ;
lib allreduce : : <search>$(with-vw)/lib ;
obj VWLibraryConsumer.o : VWLibraryConsumer.cpp headers : <include>$(with-vw)/library <include>$(with-vw)/vowpalwabbit <define>HAVE_VW ;
alias vw_objects : VWLibraryConsumer.o vw allreduce : : : <library>boost_program_options ;
echo "Linking with Vowpal Wabbit" ;
} else {
alias vw_objects ;
}
lib psd : [ glob *.cpp : VWLibraryConsumer.cpp ] vw_objects headers ;

View File

@ -8,6 +8,9 @@
using namespace std;
using namespace Moses;
namespace PSD
{
VWFileTrainConsumer::VWFileTrainConsumer(const std::string &outputFile)
{
m_os.open(outputFile.c_str());
@ -68,7 +71,7 @@ void VWFileTrainConsumer::WriteBuffer()
}
std::string VWFileTrainConsumer::EscapeSpecialChars(const std::string &str);
std::string VWFileTrainConsumer::EscapeSpecialChars(const std::string &str)
{
string out;
out = Replace(str, "|", "_PIPE_");
@ -76,3 +79,5 @@ std::string VWFileTrainConsumer::EscapeSpecialChars(const std::string &str);
out = Replace(out, " ", "_");
return out;
}
} // namespace PSD

View File

@ -8,6 +8,9 @@
using namespace std;
namespace PSD
{
//
// VWLibraryConsumer
//
@ -94,3 +97,4 @@ float VWLibraryPredictConsumer::Predict(const string &label)
return m_ex->predict();
}
} // namespace PSD