diff --git a/Jamroot b/Jamroot index 1f7ca48cd..0a146c528 100644 --- a/Jamroot +++ b/Jamroot @@ -114,10 +114,24 @@ requirements += [ option.get "with-mm" : : PT_UG ] ; requirements += [ option.get "with-mm" : : MAX_NUM_FACTORS=4 ] ; requirements += [ option.get "unlabelled-source" : : UNLABELLED_SOURCE ] ; +if [ option.get "with-lbllm" ] { + external-lib boost_serialization ; + external-lib gomp ; + requirements += boost_serialization ; + requirements += gomp ; +} + if [ option.get "with-cmph" ] { requirements += HAVE_CMPH ; } +if [ option.get "with-probing-pt" : : "yes" ] +{ + external-lib boost_serialization ; + requirements += HAVE_PROBINGPT ; + requirements += boost_serialization ; +} + project : default-build multi on @@ -145,6 +159,7 @@ build-projects lm util phrase-extract search moses moses/LM mert moses-cmd moses if [ option.get "with-mm" : : "yes" ] { alias mm : + moses/TranslationModel/UG//ptable-lookup moses/TranslationModel/UG/mm//mtt-build moses/TranslationModel/UG/mm//mtt-dump moses/TranslationModel/UG/mm//symal2mam diff --git a/OnDiskPt/Main.cpp b/OnDiskPt/Main.cpp index f2d75ed05..01a8ce1ba 100644 --- a/OnDiskPt/Main.cpp +++ b/OnDiskPt/Main.cpp @@ -66,10 +66,9 @@ int main (int argc, char * const argv[]) PhraseNode &rootNode = onDiskWrapper.GetRootSourceNode(); size_t lineNum = 0; - char line[100000]; + string line; - //while(getline(inStream, line)) - while(inStream.getline(line, 100000)) { + while(getline(inStream, line)) { lineNum++; if (lineNum%1000 == 0) cerr << "." 
<< flush; if (lineNum%10000 == 0) cerr << ":" << flush; @@ -107,8 +106,13 @@ bool Flush(const OnDiskPt::SourcePhrase *prevSourcePhrase, const OnDiskPt::Sourc return ret; } -OnDiskPt::PhrasePtr Tokenize(SourcePhrase &sourcePhrase, TargetPhrase &targetPhrase, char *line, OnDiskWrapper &onDiskWrapper, int numScores, vector &misc) +OnDiskPt::PhrasePtr Tokenize(SourcePhrase &sourcePhrase, TargetPhrase &targetPhrase, const std::string &lineStr, OnDiskWrapper &onDiskWrapper, int numScores, vector &misc) { + char line[lineStr.size() + 1]; + strcpy(line, lineStr.c_str()); + + stringstream sparseFeatures, property; + size_t scoreInd = 0; // MAIN LOOP @@ -118,6 +122,7 @@ OnDiskPt::PhrasePtr Tokenize(SourcePhrase &sourcePhrase, TargetPhrase &targetPhr 2 = scores 3 = align 4 = count + 7 = properties */ char *tok = strtok (line," "); OnDiskPt::PhrasePtr out(new Phrase()); @@ -148,29 +153,20 @@ OnDiskPt::PhrasePtr Tokenize(SourcePhrase &sourcePhrase, TargetPhrase &targetPhr targetPhrase.CreateAlignFromString(tok); break; } - case 4: - ++stage; - break; - /* case 5: { - // count info. Only store the 2nd one - float val = Moses::Scan(tok); - misc[0] = val; - ++stage; - break; - }*/ + case 4: { + // store only the 3rd one (rule count) + float val = Moses::Scan(tok); + misc[0] = val; + break; + } case 5: { - // count info. 
Only store the 2nd one - //float val = Moses::Scan(tok); - //misc[0] = val; - ++stage; + // sparse features + sparseFeatures << tok << " "; break; } case 6: { - // store only the 3rd one (rule count) - float val = Moses::Scan(tok); - misc[0] = val; - ++stage; - break; + property << tok << " "; + break; } default: cerr << "ERROR in line " << line << endl; @@ -183,6 +179,8 @@ OnDiskPt::PhrasePtr Tokenize(SourcePhrase &sourcePhrase, TargetPhrase &targetPhr } // while (tok != NULL) assert(scoreInd == numScores); + targetPhrase.SetSparseFeatures(Moses::Trim(sparseFeatures.str())); + targetPhrase.SetProperty(Moses::Trim(property.str())); targetPhrase.SortAlign(); return out; } // Tokenize() diff --git a/OnDiskPt/Main.h b/OnDiskPt/Main.h index 2b2d585d8..fcdb2cd9d 100644 --- a/OnDiskPt/Main.h +++ b/OnDiskPt/Main.h @@ -29,7 +29,7 @@ OnDiskPt::WordPtr Tokenize(OnDiskPt::Phrase &phrase , const std::string &token, bool addSourceNonTerm, bool addTargetNonTerm , OnDiskPt::OnDiskWrapper &onDiskWrapper, int retSourceTarget); OnDiskPt::PhrasePtr Tokenize(OnDiskPt::SourcePhrase &sourcePhrase, OnDiskPt::TargetPhrase &targetPhrase - , char *line, OnDiskPt::OnDiskWrapper &onDiskWrapper + , const std::string &lineStr, OnDiskPt::OnDiskWrapper &onDiskWrapper , int numScores , std::vector &misc); diff --git a/OnDiskPt/OnDiskWrapper.cpp b/OnDiskPt/OnDiskWrapper.cpp index 0120802ac..26b83eee3 100644 --- a/OnDiskPt/OnDiskWrapper.cpp +++ b/OnDiskPt/OnDiskWrapper.cpp @@ -31,7 +31,7 @@ using namespace std; namespace OnDiskPt { -int OnDiskWrapper::VERSION_NUM = 5; +int OnDiskWrapper::VERSION_NUM = 7; OnDiskWrapper::OnDiskWrapper() { diff --git a/OnDiskPt/TargetPhrase.cpp b/OnDiskPt/TargetPhrase.cpp index cb821a557..39f425b95 100644 --- a/OnDiskPt/TargetPhrase.cpp +++ b/OnDiskPt/TargetPhrase.cpp @@ -162,10 +162,14 @@ char *TargetPhrase::WriteOtherInfoToMemory(OnDiskWrapper &onDiskWrapper, size_t // allocate mem size_t numScores = onDiskWrapper.GetNumScores() ,numAlign = GetAlign().size(); + 
size_t sparseFeatureSize = m_sparseFeatures.size(); + size_t propSize = m_property.size(); - size_t memNeeded = sizeof(UINT64); // file pos (phrase id) - memNeeded += sizeof(UINT64) + 2 * sizeof(UINT64) * numAlign; // align - memNeeded += sizeof(float) * numScores; // scores + size_t memNeeded = sizeof(UINT64) // file pos (phrase id) + + sizeof(UINT64) + 2 * sizeof(UINT64) * numAlign // align + + sizeof(float) * numScores // scores + + sizeof(UINT64) + sparseFeatureSize // sparse features string + + sizeof(UINT64) + propSize; // property string char *mem = (char*) malloc(memNeeded); //memset(mem, 0, memNeeded); @@ -183,11 +187,33 @@ char *TargetPhrase::WriteOtherInfoToMemory(OnDiskWrapper &onDiskWrapper, size_t // scores memUsed += WriteScoresToMemory(mem + memUsed); + // sparse features + memUsed += WriteStringToMemory(mem + memUsed, m_sparseFeatures); + + // property string + memUsed += WriteStringToMemory(mem + memUsed, m_property); + //DebugMem(mem, memNeeded); assert(memNeeded == memUsed); return mem; } +size_t TargetPhrase::WriteStringToMemory(char *mem, const std::string &str) const +{ + size_t memUsed = 0; + UINT64 *memTmp = (UINT64*) mem; + + size_t strSize = str.size(); + memTmp[0] = strSize; + memUsed += sizeof(UINT64); + + const char *charStr = str.c_str(); + memcpy(mem + memUsed, charStr, strSize); + memUsed += strSize; + + return memUsed; +} + size_t TargetPhrase::WriteAlignToMemory(char *mem) const { size_t memUsed = 0; @@ -279,6 +305,13 @@ Moses::TargetPhrase *TargetPhrase::ConvertToMoses(const std::vectorGetScoreBreakdown().Assign(&phraseDict, m_scores); + + // sparse features + ret->GetScoreBreakdown().Assign(&phraseDict, m_sparseFeatures); + + // property + ret->SetProperties(m_property); + ret->Evaluate(mosesSP, phraseDict.GetFeaturesToApply()); return ret; @@ -299,9 +332,36 @@ UINT64 TargetPhrase::ReadOtherInfoFromFile(UINT64 filePos, std::fstream &fileTPC memUsed += ReadScoresFromFile(fileTPColl); assert((memUsed + filePos) == 
(UINT64)fileTPColl.tellg()); + // sparse features + memUsed += ReadStringFromFile(fileTPColl, m_sparseFeatures); + + // properties + memUsed += ReadStringFromFile(fileTPColl, m_property); + return memUsed; } +UINT64 TargetPhrase::ReadStringFromFile(std::fstream &fileTPColl, std::string &outStr) +{ + UINT64 bytesRead = 0; + + UINT64 strSize; + fileTPColl.read((char*) &strSize, sizeof(UINT64)); + bytesRead += sizeof(UINT64); + + if (strSize) { + char *mem = (char*) malloc(strSize + 1); + mem[strSize] = '\0'; + fileTPColl.read(mem, strSize); + outStr = string(mem); + free(mem); + + bytesRead += strSize; + } + + return bytesRead; +} + UINT64 TargetPhrase::ReadFromFile(std::fstream &fileTP) { UINT64 bytesRead = 0; diff --git a/OnDiskPt/TargetPhrase.h b/OnDiskPt/TargetPhrase.h index 5b8a30296..89b7f967e 100644 --- a/OnDiskPt/TargetPhrase.h +++ b/OnDiskPt/TargetPhrase.h @@ -50,15 +50,18 @@ class TargetPhrase: public Phrase protected: AlignType m_align; PhrasePtr m_sourcePhrase; + std::string m_sparseFeatures, m_property; std::vector m_scores; UINT64 m_filePos; size_t WriteAlignToMemory(char *mem) const; size_t WriteScoresToMemory(char *mem) const; + size_t WriteStringToMemory(char *mem, const std::string &str) const; UINT64 ReadAlignFromFile(std::fstream &fileTPColl); UINT64 ReadScoresFromFile(std::fstream &fileTPColl); + UINT64 ReadStringFromFile(std::fstream &fileTPColl, std::string &outStr); public: TargetPhrase() { @@ -110,6 +113,15 @@ public: virtual void DebugPrint(std::ostream &out, const Vocab &vocab) const; + void SetProperty(const std::string &value) + { + m_property = value; + } + + void SetSparseFeatures(const std::string &value) + { + m_sparseFeatures = value; + } }; } diff --git a/OnDiskPt/Word.cpp b/OnDiskPt/Word.cpp index 23d29cc7a..33bdb6cc5 100644 --- a/OnDiskPt/Word.cpp +++ b/OnDiskPt/Word.cpp @@ -104,14 +104,20 @@ void Word::ConvertToMoses( Moses::FactorCollection &factorColl = Moses::FactorCollection::Instance(); overwrite = 
Moses::Word(m_isNonTerminal); - // TODO: this conversion should have been done at load time. - util::TokenIter tok(vocab.GetString(m_vocabId), '|'); - - for (std::vector::const_iterator t = outputFactorsVec.begin(); t != outputFactorsVec.end(); ++t, ++tok) { - UTIL_THROW_IF2(!tok, "Too few factors in \"" << vocab.GetString(m_vocabId) << "\"; was expecting " << outputFactorsVec.size()); - overwrite.SetFactor(*t, factorColl.AddFactor(*tok, m_isNonTerminal)); + if (m_isNonTerminal) { + const std::string &tok = vocab.GetString(m_vocabId); + overwrite.SetFactor(0, factorColl.AddFactor(tok, m_isNonTerminal)); + } + else { + // TODO: this conversion should have been done at load time. + util::TokenIter tok(vocab.GetString(m_vocabId), '|'); + + for (std::vector::const_iterator t = outputFactorsVec.begin(); t != outputFactorsVec.end(); ++t, ++tok) { + UTIL_THROW_IF2(!tok, "Too few factors in \"" << vocab.GetString(m_vocabId) << "\"; was expecting " << outputFactorsVec.size()); + overwrite.SetFactor(*t, factorColl.AddFactor(*tok, m_isNonTerminal)); + } + UTIL_THROW_IF2(tok, "Too many factors in \"" << vocab.GetString(m_vocabId) << "\"; was expecting " << outputFactorsVec.size()); } - UTIL_THROW_IF2(tok, "Too many factors in \"" << vocab.GetString(m_vocabId) << "\"; was expecting " << outputFactorsVec.size()); } int Word::Compare(const Word &compare) const diff --git a/contrib/moses-speedtest/README.md b/contrib/moses-speedtest/README.md new file mode 100644 index 000000000..c95c6a400 --- /dev/null +++ b/contrib/moses-speedtest/README.md @@ -0,0 +1,122 @@ +# Moses speedtesting framework + +### Description + +This is an automatic test framework that is designed to test the day to day performance changes in Moses. + +### Set up + +#### Set up a Moses repo +Set up a Moses repo and build it with the desired configuration. 
+```bash +git clone https://github.com/moses-smt/mosesdecoder.git +cd mosesdecoder +./bjam -j10 --with-cmph=/usr/include/ +``` +You need to build Moses first, so that the testsuite knows what command you want it to use when rebuilding against newer revisions. + +#### Create a parent directory. +Create a parent directory where the **runtests.py** and related scripts and configuration file should reside. +This should also be the location of the TEST_DIR and TEST_LOG_DIR as explained in the next section. + +#### Set up a global configuration file. +You need a configuration file for the testsuite. A sample configuration file is provided in **testsuite\_config** +
+MOSES_REPO_PATH: /home/moses-speedtest/moses-standard/mosesdecoder
+DROP_CACHES_COMM: sys_drop_caches 3
+TEST_DIR: /home/moses-speedtest/phrase_tables/tests
+TEST_LOG_DIR: /home/moses-speedtest/phrase_tables/testlogs
+BASEBRANCH: RELEASE-2.1.1
+
+ +The _MOSES\_REPO\_PATH_ is the place where you have set up and built moses. +The _DROP\_CACHES\_COMM_ is the command that would beused to drop caches. It should run without needing root access. +_TEST\_DIR_ is the directory where all the tests will reside. +_TEST\_LOG\_DIR_ is the directory where the performance logs will be gathered. It should be created before running the testsuite for the first time. +_BASEBRANCH_ is the branch against which all new tests will be compared. It should normally be set to be the latest Moses stable release. + +### Creating tests + +In order to create a test one should go into the TEST_DIR and create a new folder. That folder will be used for the name of the test. +Inside that folder one should place a configuration file named **config**. The naming is mandatory. +An example such configuration file is **test\_config** + +
+Command: moses -f ... -i fff #Looks for the command in the /bin directory of the repo specified in the testsuite_config
+LDPRE: ldpreloads #Comma separated LD_LIBRARY_PATH:/, 
+Variants: vanilla, cached, ldpre #Can't have cached without ldpre or vanilla
+
+ +The _Command:_ line specifies the executable (which is looked up in the /bin directory of the repo.) and any arguments necessary. Before running the test, the script cds to the current test directory so you can use relative paths. +The _LDPRE:_ specifies if tests should be run with any LD\_PRELOAD flags. +The _Variants:_ line specifies what type of tests should we run. This particular line will run the following tests: +1. A Vanilla test meaning just the command after _Command_ will be issued. +2. A vanilla cached test meaning that after the vanilla test, the test will be run again without dropping caches in order to benchmark performance on cached filesystem. +3. A test with LD_PRELOAD ldpreloads moses -f command. For each available LDPRELOAD comma separated library to preload. +4. A cached version of all LD_PRELOAD tests. + +### Running tests. +Running the tests is done through the **runtests.py** script. + +#### Running all tests. +To run all tests, with the base branch and the latests revision (and generate new basebranch test data if such is missing) do a: +```bash +python3 runtests.py -c testsuite_config +``` + +#### Running specific tests. +The script allows the user to manually run a particular test or to test against a specific branch or revision: +
+moses-speedtest@crom:~/phrase_tables$ python3 runtests.py --help
+usage: runtests.py [-h] -c CONFIGFILE [-s SINGLETESTDIR] [-r REVISION]
+                   [-b BRANCH]
+
+A python based speedtest suite for moses.
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -c CONFIGFILE, --configfile CONFIGFILE
+                        Specify test config file
+  -s SINGLETESTDIR, --singletest SINGLETESTDIR
+                        Single test name directory. Specify directory name,
+                        not full path!
+  -r REVISION, --revision REVISION
+                        Specify a specific revison for the test.
+  -b BRANCH, --branch BRANCH
+                        Specify a branch for the test.
+
+ +### Generating HTML report. +To generate a summary of the test results use the **html\_gen.py** script. It places a file named *index.html* in the current script directory. +```bash +python3 html_gen.py testsuite_config +``` +You should use the generated file with the **style.css** file provided in the html directory. + +### Command line regression testing. +Alternatively you could check for regressions from the command line using the **check\_fo\r_regression.py** script: +```bash +python3 check_for_regression.py TESTLOGS_DIRECTORY +``` + +Alternatively the results of all tests are logged inside the the specified TESTLOGS directory so you can manually check them for additional information such as date, time, revision, branch, etc... + +### Create a cron job: +Create a cron job to run the tests daily and generate an html report. An example *cronjob* is available. +```bash +#!/bin/sh +cd /home/moses-speedtest/phrase_tables + +python3 runtests.py -c testsuite_config #Run the tests. +python3 html_gen.py testsuite_config #Generate html + +cp index.html /fs/thor4/html/www/speed-test/ #Update the html +``` + +Place the script in _/etc/cron.daily_ for dayly testing + +###### Author +Nikolay Bogoychev, 2014 + +###### License +This software is licensed under the LGPL. \ No newline at end of file diff --git a/contrib/moses-speedtest/check_for_regression.py b/contrib/moses-speedtest/check_for_regression.py new file mode 100644 index 000000000..1e269c0c6 --- /dev/null +++ b/contrib/moses-speedtest/check_for_regression.py @@ -0,0 +1,63 @@ +"""Checks if any of the latests tests has performed considerably different than + the previous ones. 
Takes the log directory as an argument.""" +import os +import sys +from testsuite_common import Result, processLogLine, bcolors, getLastTwoLines + +LOGDIR = sys.argv[1] #Get the log directory as an argument +PERCENTAGE = 5 #Default value for how much a test should change +if len(sys.argv) == 3: + PERCENTAGE = float(sys.argv[2]) #Default is 5%, but we can specify more + #line parameter + +def printResults(regressed, better, unchanged, firsttime): + """Pretty print the results in different colours""" + if regressed != []: + for item in regressed: + print(bcolors.RED + "REGRESSION! " + item.testname + " Was: "\ + + str(item.previous) + " Is: " + str(item.current) + " Change: "\ + + str(abs(item.percentage)) + "%. Revision: " + item.revision\ + + bcolors.ENDC) + print('\n') + if unchanged != []: + for item in unchanged: + print(bcolors.BLUE + "UNCHANGED: " + item.testname + " Revision: " +\ + item.revision + bcolors.ENDC) + print('\n') + if better != []: + for item in better: + print(bcolors.GREEN + "IMPROVEMENT! " + item.testname + " Was: "\ + + str(item.previous) + " Is: " + str(item.current) + " Change: "\ + + str(abs(item.percentage)) + "%. Revision: " + item.revision\ + + bcolors.ENDC) + if firsttime != []: + for item in firsttime: + print(bcolors.PURPLE + "First time test! " + item.testname +\ + " Took: " + str(item.real) + " seconds. Revision: " +\ + item.revision + bcolors.ENDC) + + +all_files = os.listdir(LOGDIR) +regressed = [] +better = [] +unchanged = [] +firsttime = [] + +#Go through all log files and find which tests have performed better. 
+for logfile in all_files: + (line1, line2) = getLastTwoLines(logfile, LOGDIR) + log1 = processLogLine(line1) + if line2 == '\n': # Empty line, only one test ever run + firsttime.append(log1) + continue + log2 = processLogLine(line2) + res = Result(log1.testname, log1.real, log2.real, log2.revision,\ + log2.branch, log1.revision, log1.branch) + if res.percentage < -PERCENTAGE: + regressed.append(res) + elif res.change > PERCENTAGE: + better.append(res) + else: + unchanged.append(res) + +printResults(regressed, better, unchanged, firsttime) diff --git a/contrib/moses-speedtest/cronjob b/contrib/moses-speedtest/cronjob new file mode 100644 index 000000000..4f7183a48 --- /dev/null +++ b/contrib/moses-speedtest/cronjob @@ -0,0 +1,7 @@ +#!/bin/sh +cd /home/moses-speedtest/phrase_tables + +python3 runtests.py -c testsuite_config #Run the tests. +python3 html_gen.py testsuite_config #Generate html + +cp index.html /fs/thor4/html/www/speed-test/ #Update the html \ No newline at end of file diff --git a/contrib/moses-speedtest/helpers/README.md b/contrib/moses-speedtest/helpers/README.md new file mode 100644 index 000000000..87efbc78f --- /dev/null +++ b/contrib/moses-speedtest/helpers/README.md @@ -0,0 +1,5 @@ +###Helpers + +This is a python script that basically gives you the equivalent of: +```echo 3 > /proc/sys/vm/drop_caches``` +You need to set it up so it is executed with root access without needing a password so that the tests can be automated. 
\ No newline at end of file diff --git a/contrib/moses-speedtest/helpers/sys_drop_caches.py b/contrib/moses-speedtest/helpers/sys_drop_caches.py new file mode 100644 index 000000000..d4796e090 --- /dev/null +++ b/contrib/moses-speedtest/helpers/sys_drop_caches.py @@ -0,0 +1,22 @@ +#!/usr/bin/spython +from sys import argv, stderr, exit +from os import linesep as ls +procfile = "/proc/sys/vm/drop_caches" +options = ["1","2","3"] +flush_type = None +try: + flush_type = argv[1][0:1] + if not flush_type in options: + raise IndexError, "not in options" + with open(procfile, "w") as f: + f.write("%s%s" % (flush_type,ls)) + exit(0) +except IndexError, e: + stderr.write("Argument %s required.%s" % (options, ls)) +except IOError, e: + stderr.write("Error writing to file.%s" % ls) +except StandardError, e: + stderr.write("Unknown Error.%s" % ls) + +exit(1) + diff --git a/contrib/moses-speedtest/html/README.md b/contrib/moses-speedtest/html/README.md new file mode 100644 index 000000000..342a8cedf --- /dev/null +++ b/contrib/moses-speedtest/html/README.md @@ -0,0 +1,5 @@ +###HTML files. + +_index.html_ is a sample generated file by this testsuite. + +_style.css_ should be placed in the html directory in which _index.html_ will be placed in order to visualize the test results in a browser. diff --git a/contrib/moses-speedtest/html/index.html b/contrib/moses-speedtest/html/index.html new file mode 100644 index 000000000..fc75b1028 --- /dev/null +++ b/contrib/moses-speedtest/html/index.html @@ -0,0 +1,32 @@ + + +Moses speed testing +Basebranch: RELEASE-2.1 Revision: c977ca2f434ed6f12a352806c088061c492b1676 + + + + + + + + + + + + + + + + + + + + + + + + + + + +
DateTimeTestnameRevisionBranchTimePrevtimePrevrevChange (%)Time (Basebranch)Change (%, Basebranch)Time (Days -2)Change (%, Days -2)Time (Days -3)Change (%, Days -3)Time (Days -4)Change (%, Days -4)Time (Days -5)Change (%, Days -5)Time (Days -6)Change (%, Days -6)Time (Days -7)Change (%, Days -7)Time (Days -14)Change (%, Days -14)Time (Years -1)Change (%, Years -1)
10.06.201410:27:57ondisk_minreord_vanilla169c3fce383bc66ae580884bfa72d60712beffefmaster21.3621.49169c3fce383bc66ae580884bfa72d60712beffef0.00625.890.1699N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:29:38minpt_reord_vanilla_cached169c3fce383bc66ae580884bfa72d60712beffefmaster9.739.52169c3fce383bc66ae580884bfa72d60712beffef-0.022112.20.2197N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:22:32ondisk_hierarchical_vanilla_cached169c3fce383bc66ae580884bfa72d60712beffefmaster25.7325.77169c3fce383bc66ae580884bfa72d60712beffef0.001633.630.2337N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:22:06ondisk_hierarchical_vanilla169c3fce383bc66ae580884bfa72d60712beffefmaster83.282.6169c3fce383bc66ae580884bfa72d60712beffef-0.0073127.590.3526N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:28:57binary_reord_vanilla169c3fce383bc66ae580884bfa72d60712beffefmaster24.5424.85169c3fce383bc66ae580884bfa72d60712beffef0.012529.090.1458N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:28:08ondisk_minreord_vanilla_cached169c3fce383bc66ae580884bfa72d60712beffefmaster10.7110.54169c3fce383bc66ae580884bfa72d60712beffef-0.016114.820.2888N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:30:00binary_minreord_vanilla169c3fce383bc66ae580884bfa72d60712beffefmaster20.8220.77169c3fce383bc66ae580884bfa72d60712beffef-0.002425.770.194N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:27:35score.hiero_vanilla_cached169c3fce383bc66ae580884bfa72d60712beffefmaster131.37130.63169c3fce383bc66ae580884bfa72d60712beffef-0.0057141.850.0791N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:29:10binary_reord_vanilla_cached169c3fce383bc66ae580884bfa72d60712beffefmaster13.4113.4169c3fce383bc66ae580884bfa72d60712beffef-0.000718.120.2605N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:29:28minpt_reord_vanilla169c3fce383bc66ae580884bfa72d60712beffefmaster17.4617.37169c3fce383bc66ae580884bfa72d60712beffef-0.005220.00.1315N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:28:22minpt_minreord_vanilla169c3fce383bc66ae580884bfa72d60712beffefmaster13.7513.56169c3fce383bc66ae580884bfa72d60712beffef-0.01417.190.2112N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:22:59ondisk_reord_vanilla169c3fce383bc66ae580884bfa72d60712beffefmaster25.2825.0169c3fce383bc66ae580884bfa72d60712beffef-0.011229.110.1412N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:28:31minpt_minreord_vanilla_cached169c3fce383bc66ae580884bfa72d60712beffefmaster8.638.6169c3fce383bc66ae580884bfa72d60712beffef-0.003511.780.2699N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:23:10ondisk_reord_vanilla_cached169c3fce383bc66ae580884bfa72d60712beffefmaster11.5711.59169c3fce383bc66ae580884bfa72d60712beffef0.001715.40.2474N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:25:24score.hiero_vanilla169c3fce383bc66ae580884bfa72d60712beffefmaster132.33130.02169c3fce383bc66ae580884bfa72d60712beffef-0.0178141.350.0802N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
10.06.201410:30:12binary_minreord_vanilla_cached169c3fce383bc66ae580884bfa72d60712beffefmaster12.4712.61169c3fce383bc66ae580884bfa72d60712beffef0.011117.890.2951N/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/AN/A
diff --git a/contrib/moses-speedtest/html/style.css b/contrib/moses-speedtest/html/style.css new file mode 100644 index 000000000..16221f91f --- /dev/null +++ b/contrib/moses-speedtest/html/style.css @@ -0,0 +1,21 @@ +table,th,td +{ +border:1px solid black; + border-collapse:collapse +} + +tr:nth-child(odd) { + background-color: Gainsboro; +} + +.better { + color: Green; +} + +.worse { + color: Red; +} + +.unchanged { + color: SkyBlue; +} \ No newline at end of file diff --git a/contrib/moses-speedtest/html_gen.py b/contrib/moses-speedtest/html_gen.py new file mode 100644 index 000000000..80e88329c --- /dev/null +++ b/contrib/moses-speedtest/html_gen.py @@ -0,0 +1,192 @@ +"""Generates HTML page containing the testresults""" +from testsuite_common import Result, processLogLine, getLastTwoLines +from runtests import parse_testconfig +import os +import sys + +from datetime import datetime, timedelta + +HTML_HEADING = """ + +Moses speed testing +""" +HTML_ENDING = "\n" + +TABLE_HEADING = """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + """ + +def get_prev_days(date, numdays): + """Gets the date numdays previous days so that we could search for + that test in the config file""" + date_obj = datetime.strptime(date, '%d.%m.%Y').date() + past_date = date_obj - timedelta(days=numdays) + return past_date.strftime('%d.%m.%Y') + +def gather_necessary_lines(logfile, date): + """Gathers the necessary lines corresponding to past dates + and parses them if they exist""" + #Get a dictionary of dates + dates = {} + dates[get_prev_days(date, 2)] = ('-2', None) + dates[get_prev_days(date, 3)] = ('-3', None) + dates[get_prev_days(date, 4)] = ('-4', None) + dates[get_prev_days(date, 5)] = ('-5', None) + dates[get_prev_days(date, 6)] = ('-6', None) + dates[get_prev_days(date, 7)] = ('-7', None) + dates[get_prev_days(date, 14)] = ('-14', None) + dates[get_prev_days(date, 365)] = ('-365', None) + + openfile = open(logfile, 'r') + for line in openfile: + if line.split()[0] in 
dates.keys(): + day = dates[line.split()[0]][0] + dates[line.split()[0]] = (day, processLogLine(line)) + openfile.close() + return dates + +def append_date_to_table(resline): + """Appends past dates to the html""" + cur_html = '' + + if resline.percentage > 0.05: #If we have improvement of more than 5% + cur_html = cur_html + '' + elif resline.percentage < -0.05: #We have a regression of more than 5% + cur_html = cur_html + '' + else: + cur_html = cur_html + '' + return cur_html + +def compare_rev(filename, rev1, rev2, branch1=False, branch2=False): + """Compare the test results of two lines. We can specify either a + revision or a branch for comparison. The first rev should be the + base version and the second revision should be the later version""" + + #In the log file the index of the revision is 2 but the index of + #the branch is 12. Alternate those depending on whether we are looking + #for a specific revision or branch. + firstidx = 2 + secondidx = 2 + if branch1 == True: + firstidx = 12 + if branch2 == True: + secondidx = 12 + + rev1line = '' + rev2line = '' + resfile = open(filename, 'r') + for line in resfile: + if rev1 == line.split()[firstidx]: + rev1line = line + elif rev2 == line.split()[secondidx]: + rev2line = line + if rev1line != '' and rev2line != '': + break + resfile.close() + if rev1line == '': + raise ValueError('Revision ' + rev1 + " was not found!") + if rev2line == '': + raise ValueError('Revision ' + rev2 + " was not found!") + + logLine1 = processLogLine(rev1line) + logLine2 = processLogLine(rev2line) + res = Result(logLine1.testname, logLine1.real, logLine2.real,\ + logLine2.revision, logLine2.branch, logLine1.revision, logLine1.branch) + + return res + +def produce_html(path, global_config): + """Produces html file for the report.""" + html = '' #The table HTML + for filenam in os.listdir(global_config.testlogs): + #Generate html for the newest two lines + #Get the lines from the config file + (ll1, ll2) = getLastTwoLines(filenam, 
global_config.testlogs) + logLine1 = processLogLine(ll1) + logLine2 = processLogLine(ll2) #This is the life from the latest revision + + #Generate html + res1 = Result(logLine1.testname, logLine1.real, logLine2.real,\ + logLine2.revision, logLine2.branch, logLine1.revision, logLine1.branch) + html = html + '' + + #Add fancy colours depending on the change + if res1.percentage > 0.05: #If we have improvement of more than 5% + html = html + '' + elif res1.percentage < -0.05: #We have a regression of more than 5% + html = html + '' + else: + html = html + '' + + #Get comparison against the base version + filenam = global_config.testlogs + '/' + filenam #Get proper directory + res2 = compare_rev(filenam, global_config.basebranch, res1.revision, branch1=True) + html = html + '' + + #Add fancy colours depending on the change + if res2.percentage > 0.05: #If we have improvement of more than 5% + html = html + '' + elif res2.percentage < -0.05: #We have a regression of more than 5% + html = html + '' + else: + html = html + '' + + #Add extra dates comparison dating from the beginning of time if they exist + past_dates = list(range(2, 8)) + past_dates.append(14) + past_dates.append(365) # Get the 1 year ago day + linesdict = gather_necessary_lines(filenam, logLine2.date) + + for days in past_dates: + act_date = get_prev_days(logLine2.date, days) + if linesdict[act_date][1] is not None: + logline_date = linesdict[act_date][1] + restemp = Result(logline_date.testname, logline_date.real, logLine2.real,\ + logLine2.revision, logLine2.branch, logline_date.revision, logline_date.branch) + html = html + append_date_to_table(restemp) + else: + html = html + '' + + + + html = html + '' #End row + + #Write out the file + basebranch_info = 'Basebranch: ' + res2.prevbranch + ' Revision: ' +\ + res2.prevrev + '' + writeoutstr = HTML_HEADING + basebranch_info + TABLE_HEADING + html + HTML_ENDING + writefile = open(path, 'w') + writefile.write(writeoutstr) + writefile.close() + +if 
__name__ == '__main__': + CONFIG = parse_testconfig(sys.argv[1]) + produce_html('index.html', CONFIG) diff --git a/contrib/moses-speedtest/runtests.py b/contrib/moses-speedtest/runtests.py new file mode 100644 index 000000000..0978c8ef2 --- /dev/null +++ b/contrib/moses-speedtest/runtests.py @@ -0,0 +1,293 @@ +"""Given a config file, runs tests""" +import os +import subprocess +import time +from argparse import ArgumentParser +from testsuite_common import processLogLine + +def parse_cmd(): + """Parse the command line arguments""" + description = "A python based speedtest suite for moses." + parser = ArgumentParser(description=description) + parser.add_argument("-c", "--configfile", action="store",\ + dest="configfile", required=True,\ + help="Specify test config file") + parser.add_argument("-s", "--singletest", action="store",\ + dest="singletestdir", default=None,\ + help="Single test name directory. Specify directory name,\ + not full path!") + parser.add_argument("-r", "--revision", action="store",\ + dest="revision", default=None,\ + help="Specify a specific revison for the test.") + parser.add_argument("-b", "--branch", action="store",\ + dest="branch", default=None,\ + help="Specify a branch for the test.") + + arguments = parser.parse_args() + return arguments + +def repoinit(testconfig): + """Determines revision and sets up the repo.""" + revision = '' + #Update the repo + os.chdir(testconfig.repo) + #Checkout specific branch, else maintain main branch + if testconfig.branch != 'master': + subprocess.call(['git', 'checkout', testconfig.branch]) + rev, _ = subprocess.Popen(['git', 'rev-parse', 'HEAD'],\ + stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate() + revision = str(rev).replace("\\n'", '').replace("b'", '') + else: + subprocess.call(['git checkout master'], shell=True) + + #Check a specific revision. Else checkout master. 
+ if testconfig.revision: + subprocess.call(['git', 'checkout', testconfig.revision]) + revision = testconfig.revision + elif testconfig.branch == 'master': + subprocess.call(['git pull'], shell=True) + rev, _ = subprocess.Popen(['git rev-parse HEAD'], stdout=subprocess.PIPE,\ + stderr=subprocess.PIPE, shell=True).communicate() + revision = str(rev).replace("\\n'", '').replace("b'", '') + + return revision + +class Configuration: + """A simple class to hold all of the configuration constatns""" + def __init__(self, repo, drop_caches, tests, testlogs, basebranch, baserev): + self.repo = repo + self.drop_caches = drop_caches + self.tests = tests + self.testlogs = testlogs + self.basebranch = basebranch + self.baserev = baserev + self.singletest = None + self.revision = None + self.branch = 'master' # Default branch + + def additional_args(self, singletest, revision, branch): + """Additional configuration from command line arguments""" + self.singletest = singletest + if revision is not None: + self.revision = revision + if branch is not None: + self.branch = branch + + def set_revision(self, revision): + """Sets the current revision that is being tested""" + self.revision = revision + + +class Test: + """A simple class to contain all information about tests""" + def __init__(self, name, command, ldopts, permutations): + self.name = name + self.command = command + self.ldopts = ldopts.replace(' ', '').split(',') #Not tested yet + self.permutations = permutations + +def parse_configfile(conffile, testdir, moses_repo): + """Parses the config file""" + command, ldopts = '', '' + permutations = [] + fileopen = open(conffile, 'r') + for line in fileopen: + line = line.split('#')[0] # Discard comments + if line == '' or line == '\n': + continue # Discard lines with comments only and empty lines + opt, args = line.split(' ', 1) # Get arguments + + if opt == 'Command:': + command = args.replace('\n', '') + command = moses_repo + '/bin/' + command + elif opt == 'LDPRE:': + 
ldopts = args.replace('\n', '') + elif opt == 'Variants:': + permutations = args.replace('\n', '').replace(' ', '').split(',') + else: + raise ValueError('Unrecognized option ' + opt) + #We use the testdir as the name. + testcase = Test(testdir, command, ldopts, permutations) + fileopen.close() + return testcase + +def parse_testconfig(conffile): + """Parses the config file for the whole testsuite.""" + repo_path, drop_caches, tests_dir, testlog_dir = '', '', '', '' + basebranch, baserev = '', '' + fileopen = open(conffile, 'r') + for line in fileopen: + line = line.split('#')[0] # Discard comments + if line == '' or line == '\n': + continue # Discard lines with comments only and empty lines + opt, args = line.split(' ', 1) # Get arguments + if opt == 'MOSES_REPO_PATH:': + repo_path = args.replace('\n', '') + elif opt == 'DROP_CACHES_COMM:': + drop_caches = args.replace('\n', '') + elif opt == 'TEST_DIR:': + tests_dir = args.replace('\n', '') + elif opt == 'TEST_LOG_DIR:': + testlog_dir = args.replace('\n', '') + elif opt == 'BASEBRANCH:': + basebranch = args.replace('\n', '') + elif opt == 'BASEREV:': + baserev = args.replace('\n', '') + else: + raise ValueError('Unrecognized option ' + opt) + config = Configuration(repo_path, drop_caches, tests_dir, testlog_dir,\ + basebranch, baserev) + fileopen.close() + return config + +def get_config(): + """Builds the config object with all necessary attributes""" + args = parse_cmd() + config = parse_testconfig(args.configfile) + config.additional_args(args.singletestdir, args.revision, args.branch) + revision = repoinit(config) + config.set_revision(revision) + return config + +def check_for_basever(testlogfile, basebranch): + """Checks if the base revision is present in the testlogs""" + filetoopen = open(testlogfile, 'r') + for line in filetoopen: + templine = processLogLine(line) + if templine.branch == basebranch: + return True + return False + +def split_time(filename): + """Splits the output of the time function into 
seperate parts. + We will write time to file, because many programs output to + stderr which makes it difficult to get only the exact results we need.""" + timefile = open(filename, 'r') + realtime = float(timefile.readline().replace('\n', '').split()[1]) + usertime = float(timefile.readline().replace('\n', '').split()[1]) + systime = float(timefile.readline().replace('\n', '').split()[1]) + timefile.close() + + return (realtime, usertime, systime) + + +def write_log(time_file, logname, config): + """Writes to a logfile""" + log_write = open(config.testlogs + '/' + logname, 'a') # Open logfile + date_run = time.strftime("%d.%m.%Y %H:%M:%S") # Get the time of the test + realtime, usertime, systime = split_time(time_file) # Get the times in a nice form + + # Append everything to a log file. + writestr = date_run + " " + config.revision + " Testname: " + logname +\ + " RealTime: " + str(realtime) + " UserTime: " + str(usertime) +\ + " SystemTime: " + str(systime) + " Branch: " + config.branch +'\n' + log_write.write(writestr) + log_write.close() + + +def execute_tests(testcase, cur_directory, config): + """Executes timed tests based on the config file""" + #Figure out the order of which tests must be executed. 
+ #Change to the current test directory + os.chdir(config.tests + '/' + cur_directory) + #Clear caches + subprocess.call(['sync'], shell=True) + subprocess.call([config.drop_caches], shell=True) + #Perform vanilla test and if a cached test exists - as well + print(testcase.name) + if 'vanilla' in testcase.permutations: + print(testcase.command) + subprocess.Popen(['time -p -o /tmp/time_moses_tests ' + testcase.command], stdout=None,\ + stderr=subprocess.PIPE, shell=True).communicate() + write_log('/tmp/time_moses_tests', testcase.name + '_vanilla', config) + if 'cached' in testcase.permutations: + subprocess.Popen(['time -p -o /tmp/time_moses_tests ' + testcase.command], stdout=None,\ + stderr=None, shell=True).communicate() + write_log('/tmp/time_moses_tests', testcase.name + '_vanilla_cached', config) + + #Now perform LD_PRELOAD tests + if 'ldpre' in testcase.permutations: + for opt in testcase.ldopts: + #Clear caches + subprocess.call(['sync'], shell=True) + subprocess.call([config.drop_caches], shell=True) + + #test + subprocess.Popen(['LD_PRELOAD ' + opt + ' time -p -o /tmp/time_moses_tests ' + testcase.command], stdout=None,\ + stderr=None, shell=True).communicate() + write_log('/tmp/time_moses_tests', testcase.name + '_ldpre_' + opt, config) + if 'cached' in testcase.permutations: + subprocess.Popen(['LD_PRELOAD ' + opt + ' time -p -o /tmp/time_moses_tests ' + testcase.command], stdout=None,\ + stderr=None, shell=True).communicate() + write_log('/tmp/time_moses_tests', testcase.name + '_ldpre_' +opt +'_cached', config) + +# Go through all the test directories and executes tests +if __name__ == '__main__': + CONFIG = get_config() + ALL_DIR = os.listdir(CONFIG.tests) + + #We should first check if any of the tests is run for the first time. 
+ #If some of them are run for the first time we should first get their + #time with the base version (usually the previous release) + FIRSTTIME = [] + TESTLOGS = [] + #Strip filenames of test underscores + for listline in os.listdir(CONFIG.testlogs): + listline = listline.replace('_vanilla', '') + listline = listline.replace('_cached', '') + listline = listline.replace('_ldpre', '') + TESTLOGS.append(listline) + for directory in ALL_DIR: + if directory not in TESTLOGS: + FIRSTTIME.append(directory) + + #Sometimes even though we have the log files, we will need to rerun them + #Against a base version, because we require a different baseversion (for + #example when a new version of Moses is released.) Therefore we should + #Check if the version of Moses that we have as a base version is in all + #of the log files. + + for logfile in os.listdir(CONFIG.testlogs): + logfile_name = CONFIG.testlogs + '/' + logfile + if not check_for_basever(logfile_name, CONFIG.basebranch): + logfile = logfile.replace('_vanilla', '') + logfile = logfile.replace('_cached', '') + logfile = logfile.replace('_ldpre', '') + FIRSTTIME.append(logfile) + FIRSTTIME = list(set(FIRSTTIME)) #Deduplicate + + if FIRSTTIME != []: + #Create a new configuration for base version tests: + BASECONFIG = Configuration(CONFIG.repo, CONFIG.drop_caches,\ + CONFIG.tests, CONFIG.testlogs, CONFIG.basebranch,\ + CONFIG.baserev) + BASECONFIG.additional_args(None, CONFIG.baserev, CONFIG.basebranch) + #Set up the repository and get its revision: + REVISION = repoinit(BASECONFIG) + BASECONFIG.set_revision(REVISION) + #Build + os.chdir(BASECONFIG.repo) + subprocess.call(['./previous.sh'], shell=True) + + #Perform tests + for directory in FIRSTTIME: + cur_testcase = parse_configfile(BASECONFIG.tests + '/' + directory +\ + '/config', directory, BASECONFIG.repo) + execute_tests(cur_testcase, directory, BASECONFIG) + + #Reset back the repository to the normal configuration + repoinit(CONFIG) + + #Builds moses + 
os.chdir(CONFIG.repo) + subprocess.call(['./previous.sh'], shell=True) + + if CONFIG.singletest: + TESTCASE = parse_configfile(CONFIG.tests + '/' +\ + CONFIG.singletest + '/config', CONFIG.singletest, CONFIG.repo) + execute_tests(TESTCASE, CONFIG.singletest, CONFIG) + else: + for directory in ALL_DIR: + cur_testcase = parse_configfile(CONFIG.tests + '/' + directory +\ + '/config', directory, CONFIG.repo) + execute_tests(cur_testcase, directory, CONFIG) diff --git a/contrib/moses-speedtest/sys_drop_caches.py b/contrib/moses-speedtest/sys_drop_caches.py new file mode 100644 index 000000000..d4796e090 --- /dev/null +++ b/contrib/moses-speedtest/sys_drop_caches.py @@ -0,0 +1,22 @@ +#!/usr/bin/spython +from sys import argv, stderr, exit +from os import linesep as ls +procfile = "/proc/sys/vm/drop_caches" +options = ["1","2","3"] +flush_type = None +try: + flush_type = argv[1][0:1] + if not flush_type in options: + raise IndexError, "not in options" + with open(procfile, "w") as f: + f.write("%s%s" % (flush_type,ls)) + exit(0) +except IndexError, e: + stderr.write("Argument %s required.%s" % (options, ls)) +except IOError, e: + stderr.write("Error writing to file.%s" % ls) +except StandardError, e: + stderr.write("Unknown Error.%s" % ls) + +exit(1) + diff --git a/contrib/moses-speedtest/test_config b/contrib/moses-speedtest/test_config new file mode 100644 index 000000000..4a480f496 --- /dev/null +++ b/contrib/moses-speedtest/test_config @@ -0,0 +1,3 @@ +Command: moses -f ... 
-i fff #Looks for the command in the /bin directory of the repo specified in the testsuite_config +LDPRE: ldpreloads #Comma separated LD_LIBRARY_PATH:/, +Variants: vanilla, cached, ldpre #Can't have cached without ldpre or vanilla diff --git a/contrib/moses-speedtest/testsuite_common.py b/contrib/moses-speedtest/testsuite_common.py new file mode 100644 index 000000000..be96f98b5 --- /dev/null +++ b/contrib/moses-speedtest/testsuite_common.py @@ -0,0 +1,54 @@ +"""Common functions of the testsuitce""" +import os +#Clour constants +class bcolors: + PURPLE = '\033[95m' + BLUE = '\033[94m' + GREEN = '\033[92m' + YELLOW = '\033[93m' + RED = '\033[91m' + ENDC = '\033[0m' + +class LogLine: + """A class to contain logfile line""" + def __init__(self, date, time, revision, testname, real, user, system, branch): + self.date = date + self.time = time + self.revision = revision + self.testname = testname + self.real = real + self.system = system + self.user = user + self.branch = branch + +class Result: + """A class to contain results of benchmarking""" + def __init__(self, testname, previous, current, revision, branch, prevrev, prevbranch): + self.testname = testname + self.previous = previous + self.current = current + self.change = previous - current + self.revision = revision + self.branch = branch + self.prevbranch = prevbranch + self.prevrev = prevrev + #Produce a percentage with fewer digits + self.percentage = float(format(1 - current/previous, '.4f')) + +def processLogLine(logline): + """Parses the log line into a nice datastructure""" + logline = logline.split() + log = LogLine(logline[0], logline[1], logline[2], logline[4],\ + float(logline[6]), float(logline[8]), float(logline[10]), logline[12]) + return log + +def getLastTwoLines(filename, logdir): + """Just a call to tail to get the diff between the last two runs""" + try: + line1, line2 = os.popen("tail -n2 " + logdir + '/' + filename) + except ValueError: #Check for new tests + tempfile = open(logdir + '/' + 
filename) + line1 = tempfile.readline() + tempfile.close() + return (line1, '\n') + return (line1, line2) diff --git a/contrib/moses-speedtest/testsuite_config b/contrib/moses-speedtest/testsuite_config new file mode 100644 index 000000000..b6ad6181c --- /dev/null +++ b/contrib/moses-speedtest/testsuite_config @@ -0,0 +1,5 @@ +MOSES_REPO_PATH: /home/moses-speedtest/moses-standard/mosesdecoder +DROP_CACHES_COMM: sys_drop_caches 3 +TEST_DIR: /home/moses-speedtest/phrase_tables/tests +TEST_LOG_DIR: /home/moses-speedtest/phrase_tables/testlogs +BASEBRANCH: RELEASE-2.1.1 \ No newline at end of file diff --git a/contrib/other-builds/CreateOnDiskPt/.project b/contrib/other-builds/CreateOnDiskPt/.project new file mode 100644 index 000000000..5bca3b8f2 --- /dev/null +++ b/contrib/other-builds/CreateOnDiskPt/.project @@ -0,0 +1,44 @@ + + + CreateOnDiskPt + + + lm + moses + OnDiskPt + search + util + + + + org.eclipse.cdt.managedbuilder.core.genmakebuilder + clean,full,incremental, + + + + + org.eclipse.cdt.managedbuilder.core.ScannerConfigBuilder + full,incremental, + + + + + + org.eclipse.cdt.core.cnature + org.eclipse.cdt.core.ccnature + org.eclipse.cdt.managedbuilder.core.managedBuildNature + org.eclipse.cdt.managedbuilder.core.ScannerConfigNature + + + + Main.cpp + 1 + PARENT-3-PROJECT_LOC/OnDiskPt/Main.cpp + + + Main.h + 1 + PARENT-3-PROJECT_LOC/OnDiskPt/Main.h + + + diff --git a/contrib/other-builds/OnDiskPt/.cproject b/contrib/other-builds/OnDiskPt/.cproject deleted file mode 100644 index e32a5baea..000000000 --- a/contrib/other-builds/OnDiskPt/.cproject +++ /dev/null @@ -1,146 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/extract-ordering/.project b/contrib/other-builds/consolidate/.project similarity index 78% rename from 
contrib/other-builds/extract-ordering/.project rename to contrib/other-builds/consolidate/.project index f95b064b7..4095862b4 100644 --- a/contrib/other-builds/extract-ordering/.project +++ b/contrib/other-builds/consolidate/.project @@ -1,6 +1,6 @@ - extract-ordering + consolidate @@ -46,19 +46,9 @@ PARENT-3-PROJECT_LOC/phrase-extract/OutputFileStream.h - SentenceAlignment.cpp + consolidate-main.cpp 1 - PARENT-3-PROJECT_LOC/phrase-extract/SentenceAlignment.cpp - - - SentenceAlignment.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/SentenceAlignment.h - - - extract-ordering-main.cpp - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ordering-main.cpp + PARENT-3-PROJECT_LOC/phrase-extract/consolidate-main.cpp tables-core.cpp diff --git a/contrib/other-builds/extract-ghkm/.cproject b/contrib/other-builds/extract-ghkm/.cproject deleted file mode 100644 index 3d8b83618..000000000 --- a/contrib/other-builds/extract-ghkm/.cproject +++ /dev/null @@ -1,138 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/extract-ghkm/.project b/contrib/other-builds/extract-ghkm/.project index b7c40f069..f9570120b 100644 --- a/contrib/other-builds/extract-ghkm/.project +++ b/contrib/other-builds/extract-ghkm/.project @@ -26,49 +26,19 @@ - Alignment.cpp + Hole.h 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Alignment.cpp + PARENT-3-PROJECT_LOC/phrase-extract/Hole.h - Alignment.h + HoleCollection.cpp 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Alignment.h + PARENT-3-PROJECT_LOC/phrase-extract/HoleCollection.cpp - AlignmentGraph.cpp + HoleCollection.h 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/AlignmentGraph.cpp - - - AlignmentGraph.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/AlignmentGraph.h - - - ComposedRule.cpp - 1 - 
PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ComposedRule.cpp - - - ComposedRule.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ComposedRule.h - - - Exception.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Exception.h - - - ExtractGHKM.cpp - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ExtractGHKM.cpp - - - ExtractGHKM.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ExtractGHKM.h + PARENT-3-PROJECT_LOC/phrase-extract/HoleCollection.h InputFileStream.cpp @@ -80,31 +50,6 @@ 1 PARENT-3-PROJECT_LOC/phrase-extract/InputFileStream.h - - Jamfile - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Jamfile - - - Main.cpp - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Main.cpp - - - Node.cpp - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Node.cpp - - - Node.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Node.h - - - Options.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Options.h - OutputFileStream.cpp 1 @@ -116,54 +61,24 @@ PARENT-3-PROJECT_LOC/phrase-extract/OutputFileStream.h - ParseTree.cpp + SentenceAlignment.cpp 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ParseTree.cpp + PARENT-3-PROJECT_LOC/phrase-extract/SentenceAlignment.cpp - ParseTree.h + SentenceAlignment.h 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ParseTree.h + PARENT-3-PROJECT_LOC/phrase-extract/SentenceAlignment.h - ScfgRule.cpp + SentenceAlignmentWithSyntax.cpp 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ScfgRule.cpp + PARENT-3-PROJECT_LOC/phrase-extract/SentenceAlignmentWithSyntax.cpp - ScfgRule.h + SentenceAlignmentWithSyntax.h 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ScfgRule.h - - - ScfgRuleWriter.cpp - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp - - - ScfgRuleWriter.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/ScfgRuleWriter.h - - - Span.cpp - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Span.cpp - - - Span.h - 1 - 
PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Span.h - - - Subgraph.cpp - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Subgraph.cpp - - - Subgraph.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/Subgraph.h + PARENT-3-PROJECT_LOC/phrase-extract/SentenceAlignmentWithSyntax.h SyntaxTree.cpp @@ -186,14 +101,9 @@ PARENT-3-PROJECT_LOC/phrase-extract/XmlTree.h - XmlTreeParser.cpp + extract-rules-main.cpp 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/XmlTreeParser.cpp - - - XmlTreeParser.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-ghkm/XmlTreeParser.h + PARENT-3-PROJECT_LOC/phrase-extract/extract-rules-main.cpp tables-core.cpp diff --git a/contrib/other-builds/extract-mixed-syntax/.cproject b/contrib/other-builds/extract-mixed-syntax/.cproject deleted file mode 100644 index d5ea4ecdb..000000000 --- a/contrib/other-builds/extract-mixed-syntax/.cproject +++ /dev/null @@ -1,134 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/extract-mixed-syntax/AlignedSentence.cpp b/contrib/other-builds/extract-mixed-syntax/AlignedSentence.cpp new file mode 100644 index 000000000..0f00d0bbf --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/AlignedSentence.cpp @@ -0,0 +1,189 @@ +/* + * AlignedSentence.cpp + * + * Created on: 18 Feb 2014 + * Author: s0565741 + */ + +#include +#include "moses/Util.h" +#include "AlignedSentence.h" +#include "Parameter.h" + +using namespace std; + + +///////////////////////////////////////////////////////////////////////////////// +AlignedSentence::AlignedSentence(int lineNum, + const std::string &source, + const std::string &target, + const std::string &alignment) +:m_lineNum(lineNum) +{ + PopulateWordVec(m_source, source); + PopulateWordVec(m_target, target); + PopulateAlignment(alignment); +} 
+ +AlignedSentence::~AlignedSentence() { + Moses::RemoveAllInColl(m_source); + Moses::RemoveAllInColl(m_target); +} + +void AlignedSentence::PopulateWordVec(Phrase &vec, const std::string &line) +{ + std::vector toks; + Moses::Tokenize(toks, line); + + vec.resize(toks.size()); + for (size_t i = 0; i < vec.size(); ++i) { + const string &tok = toks[i]; + Word *word = new Word(i, tok); + vec[i] = word; + } +} + +void AlignedSentence::PopulateAlignment(const std::string &line) +{ + vector alignStr; + Moses::Tokenize(alignStr, line); + + for (size_t i = 0; i < alignStr.size(); ++i) { + vector alignPair; + Moses::Tokenize(alignPair, alignStr[i], "-"); + assert(alignPair.size() == 2); + + int sourcePos = alignPair[0]; + int targetPos = alignPair[1]; + + if (sourcePos >= m_source.size()) { + cerr << "ERROR1:AlignedSentence=" << Debug() << endl; + cerr << "m_source=" << m_source.size() << endl; + abort(); + } + assert(sourcePos < m_source.size()); + assert(targetPos < m_target.size()); + Word *sourceWord = m_source[sourcePos]; + Word *targetWord = m_target[targetPos]; + + sourceWord->AddAlignment(targetWord); + targetWord->AddAlignment(sourceWord); + } +} + +std::string AlignedSentence::Debug() const +{ + stringstream out; + out << "m_lineNum:"; + out << m_lineNum; + out << endl; + + out << "m_source:"; + out << m_source.Debug(); + out << endl; + + out << "m_target:"; + out << m_target.Debug(); + out << endl; + + out << "consistent phrases:" << endl; + out << m_consistentPhrases.Debug(); + out << endl; + + return out.str(); +} + +std::vector AlignedSentence::GetSourceAlignmentCount() const +{ + vector ret(m_source.size()); + + for (size_t i = 0; i < m_source.size(); ++i) { + const Word &word = *m_source[i]; + ret[i] = word.GetAlignmentIndex().size(); + } + return ret; +} + +void AlignedSentence::Create(const Parameter ¶ms) +{ + CreateConsistentPhrases(params); + m_consistentPhrases.AddHieroNonTerms(params); +} + +void AlignedSentence::CreateConsistentPhrases(const Parameter 
¶ms) +{ + int countT = m_target.size(); + int countS = m_source.size(); + + m_consistentPhrases.Initialize(countS); + + // check alignments for target phrase startT...endT + for(int lengthT=1; + lengthT <= params.maxSpan && lengthT <= countT; + lengthT++) { + for(int startT=0; startT < countT-(lengthT-1); startT++) { + + // that's nice to have + int endT = startT + lengthT - 1; + + // find find aligned source words + // first: find minimum and maximum source word + int minS = 9999; + int maxS = -1; + vector< int > usedS = GetSourceAlignmentCount(); + for(int ti=startT; ti<=endT; ti++) { + const Word &word = *m_target[ti]; + const std::set &alignment = word.GetAlignmentIndex(); + + std::set::const_iterator iterAlign; + for(iterAlign = alignment.begin(); iterAlign != alignment.end(); ++iterAlign) { + int si = *iterAlign; + if (simaxS) { + maxS = si; + } + usedS[ si ]--; + } + } + + // unaligned phrases are not allowed + if( maxS == -1 ) + continue; + + // source phrase has to be within limits + if( maxS-minS >= params.maxSpan ) + continue; + + // check if source words are aligned to out of bound target words + bool out_of_bounds = false; + for(int si=minS; si<=maxS && !out_of_bounds; si++) + if (usedS[si]>0) { + out_of_bounds = true; + } + + // if out of bound, you gotta go + if (out_of_bounds) + continue; + + // done with all the checks, lets go over all consistent phrase pairs + // start point of source phrase may retreat over unaligned + for(int startS=minS; + (startS>=0 && + startS>maxS - params.maxSpan && // within length limit + (startS==minS || m_source[startS]->GetAlignment().size()==0)); // unaligned + startS--) { + // end point of source phrase may advance over unaligned + for(int endS=maxS; + (endSGetAlignment().size()==0)); // unaligned + endS++) { + + // take note that this is a valid phrase alignment + m_consistentPhrases.Add(startS, endS, startT, endT, params); + } + } + } + } +} diff --git a/contrib/other-builds/extract-mixed-syntax/AlignedSentence.h 
b/contrib/other-builds/extract-mixed-syntax/AlignedSentence.h new file mode 100644 index 000000000..915bdf90c --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/AlignedSentence.h @@ -0,0 +1,51 @@ +/* + * AlignedSentence.h + * + * Created on: 18 Feb 2014 + * Author: s0565741 + */ +#pragma once + +#include +#include +#include "ConsistentPhrases.h" +#include "Phrase.h" +#include "moses/TypeDef.h" + +class Parameter; + +class AlignedSentence { +public: + AlignedSentence(int lineNum) + :m_lineNum(lineNum) + {} + + AlignedSentence(int lineNum, + const std::string &source, + const std::string &target, + const std::string &alignment); + virtual ~AlignedSentence(); + virtual void Create(const Parameter ¶ms); + + const Phrase &GetPhrase(Moses::FactorDirection direction) const + { return (direction == Moses::Input) ? m_source : m_target; } + + const ConsistentPhrases &GetConsistentPhrases() const + { return m_consistentPhrases; } + + virtual std::string Debug() const; + + int m_lineNum; +protected: + Phrase m_source, m_target; + ConsistentPhrases m_consistentPhrases; + + void CreateConsistentPhrases(const Parameter ¶ms); + void PopulateWordVec(Phrase &vec, const std::string &line); + + // m_source and m_target MUST be populated before calling this + void PopulateAlignment(const std::string &line); + std::vector GetSourceAlignmentCount() const; +}; + + diff --git a/contrib/other-builds/extract-mixed-syntax/AlignedSentenceSyntax.cpp b/contrib/other-builds/extract-mixed-syntax/AlignedSentenceSyntax.cpp new file mode 100644 index 000000000..3d63ed044 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/AlignedSentenceSyntax.cpp @@ -0,0 +1,183 @@ +/* + * AlignedSentenceSyntax.cpp + * + * Created on: 26 Feb 2014 + * Author: hieu + */ + +#include "AlignedSentenceSyntax.h" +#include "Parameter.h" +#include "pugixml.hpp" +#include "moses/Util.h" + +using namespace std; + +AlignedSentenceSyntax::AlignedSentenceSyntax(int lineNum, + const std::string &source, + const 
std::string &target, + const std::string &alignment) +:AlignedSentence(lineNum) +,m_sourceStr(source) +,m_targetStr(target) +,m_alignmentStr(alignment) +{ +} + +AlignedSentenceSyntax::~AlignedSentenceSyntax() { + // TODO Auto-generated destructor stub +} + +void AlignedSentenceSyntax::Populate(bool isSyntax, int mixedSyntaxType, const Parameter ¶ms, + string line, Phrase &phrase, SyntaxTree &tree) +{ + // parse source and target string + if (isSyntax) { + line = "" + line + ""; + XMLParse(phrase, tree, line, params); + + if (mixedSyntaxType != 0) { + // mixed syntax. Always add [X] where there isn't 1 + tree.SetHieroLabel(params.hieroNonTerm); + if (mixedSyntaxType == 2) { + tree.AddToAll(params.hieroNonTerm); + } + } + } + else { + PopulateWordVec(phrase, line); + tree.SetHieroLabel(params.hieroNonTerm); + } + +} + +void AlignedSentenceSyntax::Create(const Parameter ¶ms) +{ + Populate(params.sourceSyntax, params.mixedSyntaxType, params, m_sourceStr, + m_source, m_sourceTree); + Populate(params.targetSyntax, params.mixedSyntaxType, params, m_targetStr, + m_target, m_targetTree); + + PopulateAlignment(m_alignmentStr); + CreateConsistentPhrases(params); + + // create labels + CreateNonTerms(); +} + +void Escape(string &text) +{ + text = Moses::Replace(text, "&", "&"); + text = Moses::Replace(text, "|", "|"); + text = Moses::Replace(text, "<", "<"); + text = Moses::Replace(text, ">", ">"); + text = Moses::Replace(text, "'", "'"); + text = Moses::Replace(text, "\"", """); + text = Moses::Replace(text, "[", "["); + text = Moses::Replace(text, "]", "]"); + +} + +void AlignedSentenceSyntax::XMLParse(Phrase &output, + SyntaxTree &tree, + const pugi::xml_node &parentNode, + const Parameter ¶ms) +{ + int childNum = 0; + for (pugi::xml_node childNode = parentNode.first_child(); childNode; childNode = childNode.next_sibling()) + { + string nodeName = childNode.name(); + + // span label + string label; + int startPos = output.size(); + + if (!nodeName.empty()) { + 
pugi::xml_attribute attribute = childNode.attribute("label"); + label = attribute.as_string(); + + // recursively call this function. For proper recursive trees + XMLParse(output, tree, childNode, params); + } + + + + // fill phrase vector + string text = childNode.value(); + Escape(text); + //cerr << childNum << " " << label << "=" << text << endl; + + std::vector toks; + Moses::Tokenize(toks, text); + + for (size_t i = 0; i < toks.size(); ++i) { + const string &tok = toks[i]; + Word *word = new Word(output.size(), tok); + output.push_back(word); + } + + // is it a labelled span? + int endPos = output.size() - 1; + + // fill syntax labels + if (!label.empty()) { + label = "[" + label + "]"; + tree.Add(startPos, endPos, label, params); + } + + ++childNum; + } + +} + +void AlignedSentenceSyntax::XMLParse(Phrase &output, + SyntaxTree &tree, + const std::string input, + const Parameter ¶ms) +{ + pugi::xml_document doc; + pugi::xml_parse_result result = doc.load(input.c_str(), + pugi::parse_default | pugi::parse_comments); + + pugi::xml_node topNode = doc.child("xml"); + XMLParse(output, tree, topNode, params); +} + +void AlignedSentenceSyntax::CreateNonTerms() +{ + for (int sourceStart = 0; sourceStart < m_source.size(); ++sourceStart) { + for (int sourceEnd = sourceStart; sourceEnd < m_source.size(); ++sourceEnd) { + ConsistentPhrases::Coll &coll = m_consistentPhrases.GetColl(sourceStart, sourceEnd); + const SyntaxTree::Labels &sourceLabels = m_sourceTree.Find(sourceStart, sourceEnd); + + ConsistentPhrases::Coll::iterator iter; + for (iter = coll.begin(); iter != coll.end(); ++iter) { + ConsistentPhrase &cp = **iter; + + int targetStart = cp.corners[2]; + int targetEnd = cp.corners[3]; + const SyntaxTree::Labels &targetLabels = m_targetTree.Find(targetStart, targetEnd); + + CreateNonTerms(cp, sourceLabels, targetLabels); + } + } + } + +} + +void AlignedSentenceSyntax::CreateNonTerms(ConsistentPhrase &cp, + const SyntaxTree::Labels &sourceLabels, + const 
SyntaxTree::Labels &targetLabels) +{ + SyntaxTree::Labels::const_iterator iterSource; + for (iterSource = sourceLabels.begin(); iterSource != sourceLabels.end(); ++iterSource) { + const string &sourceLabel = *iterSource; + + SyntaxTree::Labels::const_iterator iterTarget; + for (iterTarget = targetLabels.begin(); iterTarget != targetLabels.end(); ++iterTarget) { + const string &targetLabel = *iterTarget; + cp.AddNonTerms(sourceLabel, targetLabel); + } + } +} + + diff --git a/contrib/other-builds/extract-mixed-syntax/AlignedSentenceSyntax.h b/contrib/other-builds/extract-mixed-syntax/AlignedSentenceSyntax.h new file mode 100644 index 000000000..2e9431996 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/AlignedSentenceSyntax.h @@ -0,0 +1,46 @@ +/* + * AlignedSentenceSyntax.h + * + * Created on: 26 Feb 2014 + * Author: hieu + */ + +#pragma once + +#include "AlignedSentence.h" +#include "SyntaxTree.h" +#include "pugixml.hpp" + +class AlignedSentenceSyntax : public AlignedSentence +{ +public: + AlignedSentenceSyntax(int lineNum, + const std::string &source, + const std::string &target, + const std::string &alignment); + virtual ~AlignedSentenceSyntax(); + + void Create(const Parameter ¶ms); + + //virtual std::string Debug() const; +protected: + std::string m_sourceStr, m_targetStr, m_alignmentStr; + SyntaxTree m_sourceTree, m_targetTree; + + void XMLParse(Phrase &output, + SyntaxTree &tree, + const std::string input, + const Parameter ¶ms); + void XMLParse(Phrase &output, + SyntaxTree &tree, + const pugi::xml_node &parentNode, + const Parameter ¶ms); + void CreateNonTerms(); + void CreateNonTerms(ConsistentPhrase &cp, + const SyntaxTree::Labels &sourceLabels, + const SyntaxTree::Labels &targetLabels); + void Populate(bool isSyntax, int mixedSyntaxType, const Parameter ¶ms, + std::string line, Phrase &phrase, SyntaxTree &tree); + +}; + diff --git a/contrib/other-builds/extract-mixed-syntax/ConsistentPhrase.cpp 
b/contrib/other-builds/extract-mixed-syntax/ConsistentPhrase.cpp new file mode 100644 index 000000000..bb913da5a --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/ConsistentPhrase.cpp @@ -0,0 +1,66 @@ +/* + * ConsistentPhrase.cpp + * + * Created on: 20 Feb 2014 + * Author: hieu + */ +#include +#include "ConsistentPhrase.h" +#include "Word.h" +#include "NonTerm.h" +#include "Parameter.h" + +using namespace std; + +ConsistentPhrase::ConsistentPhrase( + int sourceStart, int sourceEnd, + int targetStart, int targetEnd, + const Parameter ¶ms) +:corners(4) +,m_hieroNonTerm(*this, params.hieroNonTerm, params.hieroNonTerm) +{ + corners[0] = sourceStart; + corners[1] = sourceEnd; + corners[2] = targetStart; + corners[3] = targetEnd; +} + +ConsistentPhrase::~ConsistentPhrase() { + // TODO Auto-generated destructor stub +} + +bool ConsistentPhrase::operator<(const ConsistentPhrase &other) const +{ + return corners < other.corners; +} + +void ConsistentPhrase::AddNonTerms(const std::string &source, + const std::string &target) +{ + m_nonTerms.push_back(NonTerm(*this, source, target)); +} + +bool ConsistentPhrase::TargetOverlap(const ConsistentPhrase &other) const +{ + if ( other.corners[3] < corners[2] || other.corners[2] > corners[3]) + return false; + + return true; +} + +std::string ConsistentPhrase::Debug() const +{ + stringstream out; + out << "[" << corners[0] << "-" << corners[1] + << "][" << corners[2] << "-" << corners[3] << "]"; + + out << "NT:"; + for (size_t i = 0; i < m_nonTerms.size(); ++i) { + const NonTerm &nonTerm = m_nonTerms[i]; + out << nonTerm.GetLabel(Moses::Input) << ":" << nonTerm.GetLabel(Moses::Output); + } + + return out.str(); +} + + diff --git a/contrib/other-builds/extract-mixed-syntax/ConsistentPhrase.h b/contrib/other-builds/extract-mixed-syntax/ConsistentPhrase.h new file mode 100644 index 000000000..865b4386f --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/ConsistentPhrase.h @@ -0,0 +1,51 @@ +/* + * ConsistentPhrase.h 
+ * + * Created on: 20 Feb 2014 + * Author: hieu + */ + +#pragma once + +#include +#include +#include +#include "moses/TypeDef.h" +#include "NonTerm.h" + +class ConsistentPhrase +{ +public: + typedef std::vector NonTerms; + + std::vector corners; + + ConsistentPhrase(const ConsistentPhrase ©); // do not implement + ConsistentPhrase(int sourceStart, int sourceEnd, + int targetStart, int targetEnd, + const Parameter ¶ms); + + virtual ~ConsistentPhrase(); + + int GetWidth(Moses::FactorDirection direction) const + { return (direction == Moses::Input) ? corners[1] - corners[0] + 1 : corners[3] - corners[2] + 1; } + + + void AddNonTerms(const std::string &source, + const std::string &target); + const NonTerms &GetNonTerms() const + { return m_nonTerms;} + const NonTerm &GetHieroNonTerm() const + { return m_hieroNonTerm;} + + bool TargetOverlap(const ConsistentPhrase &other) const; + + bool operator<(const ConsistentPhrase &other) const; + + std::string Debug() const; + +protected: + NonTerms m_nonTerms; + NonTerm m_hieroNonTerm; +}; + diff --git a/contrib/other-builds/extract-mixed-syntax/ConsistentPhrases.cpp b/contrib/other-builds/extract-mixed-syntax/ConsistentPhrases.cpp new file mode 100644 index 000000000..8978c88fa --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/ConsistentPhrases.cpp @@ -0,0 +1,103 @@ +/* + * ConsistentPhrases.cpp + * + * Created on: 20 Feb 2014 + * Author: hieu + */ +#include +#include +#include "ConsistentPhrases.h" +#include "NonTerm.h" +#include "Parameter.h" +#include "moses/Util.h" + +using namespace std; + +ConsistentPhrases::ConsistentPhrases() +{ +} + +ConsistentPhrases::~ConsistentPhrases() { + for (int start = 0; start < m_coll.size(); ++start) { + std::vector &allSourceStart = m_coll[start]; + + for (int size = 0; size < allSourceStart.size(); ++size) { + Coll &coll = allSourceStart[size]; + Moses::RemoveAllInColl(coll); + } + } +} + +void ConsistentPhrases::Initialize(size_t size) +{ + m_coll.resize(size); + + for 
(size_t sourceStart = 0; sourceStart < size; ++sourceStart) { + std::vector &allSourceStart = m_coll[sourceStart]; + allSourceStart.resize(size - sourceStart); + } +} + +void ConsistentPhrases::Add(int sourceStart, int sourceEnd, + int targetStart, int targetEnd, + const Parameter ¶ms) +{ + Coll &coll = m_coll[sourceStart][sourceEnd - sourceStart]; + ConsistentPhrase *cp = new ConsistentPhrase(sourceStart, sourceEnd, + targetStart, targetEnd, + params); + + pair inserted = coll.insert(cp); + assert(inserted.second); +} + +const ConsistentPhrases::Coll &ConsistentPhrases::GetColl(int sourceStart, int sourceEnd) const +{ + const std::vector &allSourceStart = m_coll[sourceStart]; + const Coll &ret = allSourceStart[sourceEnd - sourceStart]; + return ret; +} + +ConsistentPhrases::Coll &ConsistentPhrases::GetColl(int sourceStart, int sourceEnd) +{ + std::vector &allSourceStart = m_coll[sourceStart]; + Coll &ret = allSourceStart[sourceEnd - sourceStart]; + return ret; +} + +std::string ConsistentPhrases::Debug() const +{ + std::stringstream out; + for (int start = 0; start < m_coll.size(); ++start) { + const std::vector &allSourceStart = m_coll[start]; + + for (int size = 0; size < allSourceStart.size(); ++size) { + const Coll &coll = allSourceStart[size]; + + Coll::const_iterator iter; + for (iter = coll.begin(); iter != coll.end(); ++iter) { + const ConsistentPhrase &consistentPhrase = **iter; + out << consistentPhrase.Debug() << endl; + } + } + } + + return out.str(); +} + +void ConsistentPhrases::AddHieroNonTerms(const Parameter ¶ms) +{ + // add [X] labels everywhere + for (int i = 0; i < m_coll.size(); ++i) { + vector &inner = m_coll[i]; + for (int j = 0; j < inner.size(); ++j) { + ConsistentPhrases::Coll &coll = inner[j]; + ConsistentPhrases::Coll::iterator iter; + for (iter = coll.begin(); iter != coll.end(); ++iter) { + ConsistentPhrase &cp = **iter; + cp.AddNonTerms(params.hieroNonTerm, params.hieroNonTerm); + } + } + } +} + diff --git 
a/contrib/other-builds/extract-mixed-syntax/ConsistentPhrases.h b/contrib/other-builds/extract-mixed-syntax/ConsistentPhrases.h new file mode 100644 index 000000000..3daf6b7ff --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/ConsistentPhrases.h @@ -0,0 +1,40 @@ +/* + * ConsistentPhrases.h + * + * Created on: 20 Feb 2014 + * Author: hieu + */ +#pragma once + +#include +#include +#include +#include "ConsistentPhrase.h" + +class Word; +class Parameter; + +class ConsistentPhrases { +public: + typedef std::set Coll; + + ConsistentPhrases(); + virtual ~ConsistentPhrases(); + + void Initialize(size_t size); + + void Add(int sourceStart, int sourceEnd, + int targetStart, int targetEnd, + const Parameter ¶ms); + + void AddHieroNonTerms(const Parameter ¶ms); + + const Coll &GetColl(int sourceStart, int sourceEnd) const; + Coll &GetColl(int sourceStart, int sourceEnd); + + std::string Debug() const; + +protected: + std::vector< std::vector > m_coll; +}; + diff --git a/contrib/other-builds/extract-mixed-syntax/Global.cpp b/contrib/other-builds/extract-mixed-syntax/Global.cpp deleted file mode 100644 index 27aeb4b95..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Global.cpp +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Global.cpp - * extract - * - * Created by Hieu Hoang on 01/02/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ - -#include "Global.h" - -bool g_debug = false; - -Global::Global() -: minHoleSpanSourceDefault(2) -, maxHoleSpanSourceDefault(7) -, minHoleSpanSourceSyntax(1) -, maxHoleSpanSourceSyntax(1000) -, maxUnaligned(5) - -, maxSymbols(5) -, maxNonTerm(3) -, maxNonTermDefault(2) - -// int minHoleSize(1) -// int minSubPhraseSize(1) // minimum size of a remaining lexical phrase -, glueGrammarFlag(false) -, unknownWordLabelFlag(false) -//bool zipFiles(false) -, sourceSyntax(true) -, targetSyntax(false) -, mixed(true) -, uppermostOnly(true) -, allowDefaultNonTermEdge(true) -, gzOutput(false) - -{} diff --git a/contrib/other-builds/extract-mixed-syntax/Global.h b/contrib/other-builds/extract-mixed-syntax/Global.h deleted file mode 100644 index 41cdbf0ce..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Global.h +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once -/* - * Global.h - * extract - * - * Created by Hieu Hoang on 01/02/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. - * - */ -#include -#include -#include - -class Global -{ -public: - int minHoleSpanSourceDefault; - int maxHoleSpanSourceDefault; - int minHoleSpanSourceSyntax; - int maxHoleSpanSourceSyntax; - - int maxSymbols; - bool glueGrammarFlag; - bool unknownWordLabelFlag; - int maxNonTerm; - int maxNonTermDefault; - bool sourceSyntax; - bool targetSyntax; - bool mixed; - int maxUnaligned; - bool uppermostOnly; - bool allowDefaultNonTermEdge; - bool gzOutput; - - Global(); - - Global(const Global&); - -}; - -extern bool g_debug; - -#define DEBUG_OUTPUT() void DebugOutput() const; - - diff --git a/contrib/other-builds/extract-mixed-syntax/Lattice.cpp b/contrib/other-builds/extract-mixed-syntax/Lattice.cpp deleted file mode 100644 index 2b9ebac6e..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Lattice.cpp +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Lattice.cpp - * extract - * - * Created by Hieu Hoang on 18/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ - -#include -#include "Lattice.h" -#include "LatticeNode.h" -#include "Tunnel.h" -#include "TunnelCollection.h" -#include "SyntaxTree.h" -#include "SentenceAlignment.h" -#include "tables-core.h" -#include "Rule.h" -#include "RuleCollection.h" - -using namespace std; - -Lattice::Lattice(size_t sourceSize) -:m_stacks(sourceSize + 1) -{ -} - -Lattice::~Lattice() -{ - std::vector::iterator iterStack; - for (iterStack = m_stacks.begin(); iterStack != m_stacks.end(); ++iterStack) - { - Stack &stack = *iterStack; - RemoveAllInColl(stack); - } -} - -void Lattice::CreateArcs(size_t startPos, const TunnelCollection &tunnelColl, const SentenceAlignment &sentence, const Global &global) -{ - // term - Stack &startStack = GetStack(startPos); - - LatticeNode *node = new LatticeNode(startPos, &sentence); - startStack.push_back(node); - - // non-term - for (size_t endPos = startPos + 1; endPos <= sentence.source.size(); ++endPos) - { - const TunnelList &tunnels = tunnelColl.GetTunnels(startPos, endPos - 1); - - TunnelList::const_iterator iterHole; - for (iterHole = tunnels.begin(); iterHole != tunnels.end(); ++iterHole) - { - const Tunnel &tunnel = *iterHole; - CreateArcsUsing1Hole(tunnel, sentence, global); - } - } -} - -void Lattice::CreateArcsUsing1Hole(const Tunnel &tunnel, const SentenceAlignment &sentence, const Global &global) -{ - size_t startPos = tunnel.GetRange(0).GetStartPos() - , endPos = tunnel.GetRange(0).GetEndPos(); - size_t numSymbols = tunnel.GetRange(0).GetWidth(); - assert(numSymbols > 0); - - Stack &startStack = GetStack(startPos); - - - // non-terms. 
cartesian product of source & target labels - assert(startPos == tunnel.GetRange(0).GetStartPos() && endPos == tunnel.GetRange(0).GetEndPos()); - size_t startT = tunnel.GetRange(1).GetStartPos() - ,endT = tunnel.GetRange(1).GetEndPos(); - - const SyntaxNodes &nodesS = sentence.sourceTree.GetNodes(startPos, endPos); - const SyntaxNodes &nodesT = sentence.targetTree.GetNodes(startT, endT ); - - SyntaxNodes::const_iterator iterS, iterT; - for (iterS = nodesS.begin(); iterS != nodesS.end(); ++iterS) - { - const SyntaxNode *syntaxNodeS = *iterS; - - for (iterT = nodesT.begin(); iterT != nodesT.end(); ++iterT) - { - const SyntaxNode *syntaxNodeT = *iterT; - - bool isSyntax = syntaxNodeS->IsSyntax() || syntaxNodeT->IsSyntax(); - size_t maxSourceNonTermSpan = isSyntax ? global.maxHoleSpanSourceSyntax : global.maxHoleSpanSourceDefault; - - if (maxSourceNonTermSpan >= endPos - startPos) - { - LatticeNode *node = new LatticeNode(tunnel, syntaxNodeS, syntaxNodeT); - startStack.push_back(node); - } - } - } -} - -Stack &Lattice::GetStack(size_t startPos) -{ - assert(startPos < m_stacks.size()); - return m_stacks[startPos]; -} - -const Stack &Lattice::GetStack(size_t startPos) const -{ - assert(startPos < m_stacks.size()); - return m_stacks[startPos]; -} - -void Lattice::CreateRules(size_t startPos, const SentenceAlignment &sentence, const Global &global) -{ - const Stack &startStack = GetStack(startPos); - - Stack::const_iterator iterStack; - for (iterStack = startStack.begin(); iterStack != startStack.end(); ++iterStack) - { - const LatticeNode *node = *iterStack; - Rule *initRule = new Rule(node); - - if (initRule->CanRecurse(global, sentence.GetTunnelCollection())) - { // may or maynot be valid, but can continue to build on this rule - initRule->CreateRules(m_rules, *this, sentence, global); - } - - if (initRule->IsValid(global, sentence.GetTunnelCollection())) - { // add to rule collection - m_rules.Add(global, initRule, sentence); - } - else - { - delete initRule; - } - - - 
} -} - -Stack Lattice::GetNonTermNode(const Range &sourceRange) const -{ - Stack ret; - size_t sourcePos = sourceRange.GetStartPos(); - - const Stack &origStack = GetStack(sourcePos); - Stack::const_iterator iter; - for (iter = origStack.begin(); iter != origStack.end(); ++iter) - { - LatticeNode *node = *iter; - const Range &nodeRangeS = node->GetSourceRange(); - - assert(nodeRangeS.GetStartPos() == sourceRange.GetStartPos()); - - if (! node->IsTerminal() && nodeRangeS.GetEndPos() == sourceRange.GetEndPos()) - { - ret.push_back(node); - } - } - - return ret; -} - -std::ostream& operator<<(std::ostream &out, const Lattice &obj) -{ - std::vector::const_iterator iter; - for (iter = obj.m_stacks.begin(); iter != obj.m_stacks.end(); ++iter) - { - const Stack &stack = *iter; - - Stack::const_iterator iterStack; - for (iterStack = stack.begin(); iterStack != stack.end(); ++iterStack) - { - const LatticeNode &node = **iterStack; - out << node << " "; - } - } - - return out; -} - - diff --git a/contrib/other-builds/extract-mixed-syntax/Lattice.h b/contrib/other-builds/extract-mixed-syntax/Lattice.h deleted file mode 100644 index c88aa0844..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Lattice.h +++ /dev/null @@ -1,47 +0,0 @@ -#pragma once -/* - * Lattice.h - * extract - * - * Created by Hieu Hoang on 18/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ -#include -#include -#include "RuleCollection.h" - -class Global; -class LatticeNode; -class Tunnel; -class TunnelCollection; -class SentenceAlignment; - -typedef std::vector Stack; - -class Lattice -{ - friend std::ostream& operator<<(std::ostream&, const Lattice&); - - std::vector m_stacks; - RuleCollection m_rules; - - Stack &GetStack(size_t endPos); - - void CreateArcsUsing1Hole(const Tunnel &tunnel, const SentenceAlignment &sentence, const Global &global); - -public: - Lattice(size_t sourceSize); - ~Lattice(); - - void CreateArcs(size_t startPos, const TunnelCollection &tunnelColl, const SentenceAlignment &sentence, const Global &global); - void CreateRules(size_t startPos, const SentenceAlignment &sentence, const Global &global); - - const Stack &GetStack(size_t startPos) const; - const RuleCollection &GetRules() const - { return m_rules; } - - Stack GetNonTermNode(const Range &sourceRange) const; - -}; - diff --git a/contrib/other-builds/extract-mixed-syntax/LatticeNode.cpp b/contrib/other-builds/extract-mixed-syntax/LatticeNode.cpp deleted file mode 100644 index 8f0cbfc0f..000000000 --- a/contrib/other-builds/extract-mixed-syntax/LatticeNode.cpp +++ /dev/null @@ -1,149 +0,0 @@ -/* - * LatticeNode.cpp - * extract - * - * Created by Hieu Hoang on 18/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ -#include -#include "LatticeNode.h" -#include "SyntaxTree.h" -#include "Tunnel.h" -#include "SentenceAlignment.h" -#include "SymbolSequence.h" - -size_t LatticeNode::s_count = 0; - -using namespace std; - -// for terms -LatticeNode::LatticeNode(size_t pos, const SentenceAlignment *sentence) -:m_tunnel(NULL) -,m_isTerminal(true) -,m_sourceTreeNode(NULL) -,m_targetTreeNode(NULL) -,m_sentence(sentence) -,m_sourceRange(pos, pos) -{ - s_count++; - //cerr << *this << endl; -} - -// for non-terms -LatticeNode::LatticeNode(const Tunnel &tunnel, const SyntaxNode *sourceTreeNode, const SyntaxNode *targetTreeNode) -:m_tunnel(&tunnel) -,m_isTerminal(false) -,m_sourceTreeNode(sourceTreeNode) -,m_targetTreeNode(targetTreeNode) -,m_sentence(NULL) -,m_sourceRange(tunnel.GetRange(0)) -{ - s_count++; - //cerr << *this << endl; -} - -bool LatticeNode::IsSyntax() const -{ - assert(!m_isTerminal); - bool ret = m_sourceTreeNode->IsSyntax() || m_targetTreeNode->IsSyntax(); - return ret; -} - -size_t LatticeNode::GetNumSymbols(size_t direction) const -{ - return 1; -} - -int LatticeNode::Compare(const LatticeNode &otherNode) const -{ - int ret = 0; - if (m_isTerminal != otherNode.m_isTerminal) - { - ret = m_isTerminal ? -1 : 1; - } - - // both term or non-term - else if (m_isTerminal) - { // term. compare source span - if (m_sourceRange.GetStartPos() == otherNode.m_sourceRange.GetStartPos()) - ret = 0; - else - ret = (m_sourceRange.GetStartPos() < otherNode.m_sourceRange.GetStartPos()) ? -1 : +1; - } - else - { // non-term. compare source span and BOTH label - assert(!m_isTerminal); - assert(!otherNode.m_isTerminal); - - if (m_sourceTreeNode->IsSyntax()) - { - ret = m_tunnel->Compare(*otherNode.m_tunnel, 0); - if (ret == 0 && m_sourceTreeNode->GetLabel() != otherNode.m_sourceTreeNode->GetLabel()) - { - ret = (m_sourceTreeNode->GetLabel() < otherNode.m_sourceTreeNode->GetLabel()) ? 
-1 : +1; - } - } - - if (ret == 0 && m_targetTreeNode->IsSyntax()) - { - ret = m_tunnel->Compare(*otherNode.m_tunnel, 1); - if (ret == 0 && m_targetTreeNode->GetLabel() != otherNode.m_targetTreeNode->GetLabel()) - { - ret = (m_targetTreeNode->GetLabel() < otherNode.m_targetTreeNode->GetLabel()) ? -1 : +1; - } - } - } - - return ret; -} - -void LatticeNode::CreateSymbols(size_t direction, SymbolSequence &symbols) const -{ - if (m_isTerminal) - { - /* - const std::vector &words = (direction == 0 ? m_sentence->source : m_sentence->target); - size_t startPos = m_tunnel.GetStart(direction) - ,endPos = m_tunnel.GetEnd(direction); - - for (size_t pos = startPos; pos <= endPos; ++pos) - { - Symbol symbol(words[pos], pos); - symbols.Add(symbol); - } - */ - } - else - { // output both - - Symbol symbol(m_sourceTreeNode->GetLabel(), m_targetTreeNode->GetLabel() - , m_tunnel->GetRange(0).GetStartPos(), m_tunnel->GetRange(0).GetEndPos() - , m_tunnel->GetRange(1).GetStartPos(), m_tunnel->GetRange(1).GetEndPos() - , m_sourceTreeNode->IsSyntax(), m_targetTreeNode->IsSyntax()); - - symbols.Add(symbol); - } - -} - -std::ostream& operator<<(std::ostream &out, const LatticeNode &obj) -{ - if (obj.m_isTerminal) - { - assert(obj.m_sourceRange.GetWidth() == 1); - size_t pos = obj.m_sourceRange.GetStartPos(); - - const SentenceAlignment &sentence = *obj.m_sentence; - out << obj.m_sourceRange << "=" << sentence.source[pos]; - } - else - { - assert(obj.m_tunnel); - out << obj.GetTunnel() << "=" << obj.m_sourceTreeNode->GetLabel() << obj.m_targetTreeNode->GetLabel() << " "; - } - - return out; -} - - diff --git a/contrib/other-builds/extract-mixed-syntax/LatticeNode.h b/contrib/other-builds/extract-mixed-syntax/LatticeNode.h deleted file mode 100644 index 73ea6a224..000000000 --- a/contrib/other-builds/extract-mixed-syntax/LatticeNode.h +++ /dev/null @@ -1,77 +0,0 @@ -#pragma once -/* - * LatticeNode.h - * extract - * - * Created by Hieu Hoang on 18/07/2010. 
- * Copyright 2010 __MyCompanyName__. All rights reserved. - * - */ -#include -#include -#include -#include "Range.h" - -class Tunnel; -class SyntaxNode; -class SentenceAlignment; -class SymbolSequence; - -class LatticeNode -{ - friend std::ostream& operator<<(std::ostream&, const LatticeNode&); - - bool m_isTerminal; - - // for terms & non-term - Range m_sourceRange; - - // non-terms. source range should be same as m_sourceRange - const Tunnel *m_tunnel; - -public: - static size_t s_count; - - - - const SyntaxNode *m_sourceTreeNode, *m_targetTreeNode; - const SentenceAlignment *m_sentence; - - // for terms - LatticeNode(size_t pos, const SentenceAlignment *sentence); - - // for non-terms - LatticeNode(const Tunnel &tunnel, const SyntaxNode *sourceTreeNode, const SyntaxNode *targetTreeNode); - - bool IsTerminal() const - { return m_isTerminal; } - - bool IsSyntax() const; - - size_t GetNumSymbols(size_t direction) const; - - std::string ToString() const; - - int Compare(const LatticeNode &otherNode) const; - - void CreateSymbols(size_t direction, SymbolSequence &symbols) const; - - const Tunnel &GetTunnel() const - { - assert(m_tunnel); - return *m_tunnel; - } - - const Range &GetSourceRange() const - { - return m_sourceRange; - } - const SyntaxNode &GetSyntaxNode(size_t direction) const - { - const SyntaxNode *node = direction == 0 ? 
m_sourceTreeNode : m_targetTreeNode; - assert(node); - return *node; - } - -}; - diff --git a/contrib/other-builds/extract-mixed-syntax/Main.cpp b/contrib/other-builds/extract-mixed-syntax/Main.cpp new file mode 100644 index 000000000..89875daa9 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Main.cpp @@ -0,0 +1,174 @@ +#include +#include +#include + +#include "Main.h" +#include "InputFileStream.h" +#include "OutputFileStream.h" +#include "AlignedSentence.h" +#include "AlignedSentenceSyntax.h" +#include "Parameter.h" +#include "Rules.h" + +using namespace std; + +bool g_debug = false; + +int main(int argc, char** argv) +{ + cerr << "Starting" << endl; + + Parameter params; + + namespace po = boost::program_options; + po::options_description desc("Options"); + desc.add_options() + ("help", "Print help messages") + ("MaxSpan", po::value()->default_value(params.maxSpan), "Max (source) span of a rule. ie. number of words in the source") + ("GlueGrammar", po::value()->default_value(params.gluePath), "Output glue grammar to here") + ("SentenceOffset", po::value()->default_value(params.sentenceOffset), "Starting sentence id. Not used") + ("GZOutput", "Compress extract files") + ("MaxNonTerm", po::value()->default_value(params.maxNonTerm), "Maximum number of non-terms allowed per rule") + ("MaxHieroNonTerm", po::value()->default_value(params.maxHieroNonTerm), "Maximum number of Hiero non-term. Usually, --MaxNonTerm is the normal constraint") + ("MinHoleSource", po::value()->default_value(params.minHoleSource), "Minimum source span for a non-term.") + + ("SourceSyntax", "Source sentence is a parse tree") + ("TargetSyntax", "Target sentence is a parse tree") + ("MixedSyntaxType", po::value()->default_value(params.mixedSyntaxType), "Hieu's Mixed syntax type. 0(default)=no mixed syntax, 1=add [X] only if no syntactic label. 2=add [X] everywhere") + ("MultiLabel", po::value()->default_value(params.multiLabel), "What to do with multiple labels on the same span. 
0(default)=keep them all, 1=keep only top-most, 2=keep only bottom-most") + ("HieroSourceLHS", "Always use Hiero source LHS? Default = 0") + ("MaxSpanFreeNonTermSource", po::value()->default_value(params.maxSpanFreeNonTermSource), "Max number of words covered by beginning/end NT. Default = 0 (no limit)") + ("NoNieceTerminal", "Don't extract rule if 1 of the non-term covers the same word as 1 of the terminals") + ("MaxScope", po::value()->default_value(params.maxScope), "maximum scope (see Hopkins and Langmead (2010)). Default is HIGH") + ("SpanLength", "Property - span length of RHS each non-term") + ("NonTermContext", "Property - left and right, inside and outside words of each non-term"); + + po::variables_map vm; + try + { + po::store(po::parse_command_line(argc, argv, desc), + vm); // can throw + + /** --help option + */ + if ( vm.count("help") || argc < 5 ) + { + std::cout << argv[0] << " target source alignment [options...]" << std::endl + << desc << std::endl; + return EXIT_SUCCESS; + } + + po::notify(vm); // throws on error, so do after help in case + // there are any problems + } + catch(po::error& e) + { + std::cerr << "ERROR: " << e.what() << std::endl << std::endl; + std::cerr << desc << std::endl; + return EXIT_FAILURE; + } + + if (vm.count("MaxSpan")) params.maxSpan = vm["MaxSpan"].as(); + if (vm.count("GZOutput")) params.gzOutput = true; + if (vm.count("GlueGrammar")) params.gluePath = vm["GlueGrammar"].as(); + if (vm.count("SentenceOffset")) params.sentenceOffset = vm["SentenceOffset"].as(); + if (vm.count("MaxNonTerm")) params.maxNonTerm = vm["MaxNonTerm"].as(); + if (vm.count("MaxHieroNonTerm")) params.maxHieroNonTerm = vm["MaxHieroNonTerm"].as(); + if (vm.count("MinHoleSource")) params.minHoleSource = vm["MinHoleSource"].as(); + + if (vm.count("SourceSyntax")) params.sourceSyntax = true; + if (vm.count("TargetSyntax")) params.targetSyntax = true; + if (vm.count("MixedSyntaxType")) params.mixedSyntaxType = vm["MixedSyntaxType"].as(); + if 
(vm.count("MultiLabel")) params.multiLabel = vm["MultiLabel"].as(); + if (vm.count("HieroSourceLHS")) params.hieroSourceLHS = true; + if (vm.count("MaxSpanFreeNonTermSource")) params.maxSpanFreeNonTermSource = vm["MaxSpanFreeNonTermSource"].as(); + if (vm.count("NoNieceTerminal")) params.nieceTerminal = false; + if (vm.count("MaxScope")) params.maxScope = vm["MaxScope"].as(); + + // properties + if (vm.count("SpanLength")) params.spanLength = true; + if (vm.count("NonTermContext")) params.nonTermContext = true; + + // input files; + string pathTarget = argv[1]; + string pathSource = argv[2]; + string pathAlignment = argv[3]; + + string pathExtract = argv[4]; + string pathExtractInv = pathExtract + ".inv"; + if (params.gzOutput) { + pathExtract += ".gz"; + pathExtractInv += ".gz"; + } + + Moses::InputFileStream strmTarget(pathTarget); + Moses::InputFileStream strmSource(pathSource); + Moses::InputFileStream strmAlignment(pathAlignment); + Moses::OutputFileStream extractFile(pathExtract); + Moses::OutputFileStream extractInvFile(pathExtractInv); + + + // MAIN LOOP + int lineNum = 1; + string lineTarget, lineSource, lineAlignment; + while (getline(strmTarget, lineTarget)) { + if (lineNum % 10000 == 0) { + cerr << lineNum << " "; + } + + bool success; + success = getline(strmSource, lineSource); + if (!success) { + throw "Couldn't read source"; + } + success = getline(strmAlignment, lineAlignment); + if (!success) { + throw "Couldn't read alignment"; + } + + /* + cerr << "lineTarget=" << lineTarget << endl; + cerr << "lineSource=" << lineSource << endl; + cerr << "lineAlignment=" << lineAlignment << endl; + */ + + AlignedSentence *alignedSentence; + + if (params.sourceSyntax || params.targetSyntax) { + alignedSentence = new AlignedSentenceSyntax(lineNum, lineSource, lineTarget, lineAlignment); + } + else { + alignedSentence = new AlignedSentence(lineNum, lineSource, lineTarget, lineAlignment); + } + + alignedSentence->Create(params); + //cerr << 
alignedSentence->Debug(); + + Rules rules(*alignedSentence); + rules.Extend(params); + rules.Consolidate(params); + //cerr << rules.Debug(); + + rules.Output(extractFile, true, params); + rules.Output(extractInvFile, false, params); + + delete alignedSentence; + + ++lineNum; + } + + if (!params.gluePath.empty()) { + Moses::OutputFileStream glueFile(params.gluePath); + CreateGlueGrammar(glueFile); + } + + cerr << "Finished" << endl; +} + +void CreateGlueGrammar(Moses::OutputFileStream &glueFile) +{ + glueFile << " [X] ||| [S] ||| 1 ||| ||| 0" << endl + << "[X][S] [X] ||| [X][S] [S] ||| 1 ||| 0-0 ||| 0" << endl + << "[X][S] [X][X] [X] ||| [X][S] [X][X] [S] ||| 2.718 ||| 0-0 1-1 ||| 0" << endl; + +} diff --git a/contrib/other-builds/extract-mixed-syntax/Main.h b/contrib/other-builds/extract-mixed-syntax/Main.h new file mode 100644 index 000000000..9744ba389 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Main.h @@ -0,0 +1,12 @@ +/* + * Main.h + * + * Created on: 28 Feb 2014 + * Author: hieu + */ +#pragma once + +#include "OutputFileStream.h" + +void CreateGlueGrammar(Moses::OutputFileStream &glueFile); + diff --git a/contrib/other-builds/extract-mixed-syntax/Makefile b/contrib/other-builds/extract-mixed-syntax/Makefile index b992b161f..f612b8667 100644 --- a/contrib/other-builds/extract-mixed-syntax/Makefile +++ b/contrib/other-builds/extract-mixed-syntax/Makefile @@ -1,13 +1,17 @@ -all: extract +all: extract-mixed-syntax clean: rm -f *.o extract-mixed-syntax .cpp.o: - g++ -O6 -g -c $< + g++ -O4 -g -c -I../../../boost/include -I../../../ $< -extract: tables-core.o extract.o SyntaxTree.o XmlTree.o Tunnel.o Lattice.o LatticeNode.o SentenceAlignment.o Global.o InputFileStream.o TunnelCollection.o RuleCollection.o Rule.o Symbol.o SymbolSequence.o Range.o OutputFileStream.o +OBJECTS = AlignedSentence.o ConsistentPhrase.o ConsistentPhrases.o InputFileStream.o \ + Main.o OutputFileStream.o Parameter.o Phrase.o Rule.o Rules.o RuleSymbol.o \ + SyntaxTree.o Word.o 
NonTerm.o RulePhrase.o AlignedSentenceSyntax.o pugixml.o - g++ tables-core.o extract.o SyntaxTree.o XmlTree.o Tunnel.o Lattice.o LatticeNode.o SentenceAlignment.o Global.o InputFileStream.o TunnelCollection.o RuleCollection.o Rule.o Symbol.o SymbolSequence.o Range.o OutputFileStream.o -lz -lboost_iostreams-mt -o extract-mixed-syntax +extract-mixed-syntax: $(OBJECTS) + + g++ $(OBJECTS) -L../../../boost/lib64 -lz -lboost_iostreams-mt -lboost_program_options-mt -o extract-mixed-syntax diff --git a/contrib/other-builds/extract-mixed-syntax/NonTerm.cpp b/contrib/other-builds/extract-mixed-syntax/NonTerm.cpp new file mode 100644 index 000000000..9e7d0dcaa --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/NonTerm.cpp @@ -0,0 +1,65 @@ +/* + * NonTerm.cpp + * + * Created on: 22 Feb 2014 + * Author: hieu + */ + +#include +#include "NonTerm.h" +#include "Word.h" +#include "ConsistentPhrase.h" +#include "Parameter.h" + +using namespace std; + +NonTerm::NonTerm(const ConsistentPhrase &consistentPhrase, + const std::string &source, + const std::string &target) +:m_consistentPhrase(&consistentPhrase) +,m_source(source) +,m_target(target) +{ + // TODO Auto-generated constructor stub + +} + +NonTerm::~NonTerm() { + // TODO Auto-generated destructor stub +} + +std::string NonTerm::Debug() const +{ + stringstream out; + out << m_source << m_target; + out << m_consistentPhrase->Debug(); + return out.str(); +} + +void NonTerm::Output(std::ostream &out) const +{ + out << m_source << m_target; +} + +void NonTerm::Output(std::ostream &out, Moses::FactorDirection direction) const +{ + out << GetLabel(direction); +} + +const std::string &NonTerm::GetLabel(Moses::FactorDirection direction) const +{ + return (direction == Moses::Input) ? 
m_source : m_target; +} + +bool NonTerm::IsHiero(Moses::FactorDirection direction, const Parameter ¶ms) const +{ + const std::string &label = NonTerm::GetLabel(direction); + return label == params.hieroNonTerm; +} + +bool NonTerm::IsHiero(const Parameter ¶ms) const +{ + return IsHiero(Moses::Input, params) && IsHiero(Moses::Output, params); +} +int NonTerm::GetWidth(Moses::FactorDirection direction) const +{ return GetConsistentPhrase().GetWidth(direction); } diff --git a/contrib/other-builds/extract-mixed-syntax/NonTerm.h b/contrib/other-builds/extract-mixed-syntax/NonTerm.h new file mode 100644 index 000000000..5b3bb9f04 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/NonTerm.h @@ -0,0 +1,47 @@ +/* + * NonTerm.h + * + * Created on: 22 Feb 2014 + * Author: hieu + */ +#pragma once +#include +#include "RuleSymbol.h" +#include "moses/TypeDef.h" + +class ConsistentPhrase; +class Parameter; + +class NonTerm : public RuleSymbol +{ +public: + + NonTerm(const ConsistentPhrase &consistentPhrase, + const std::string &source, + const std::string &target); + virtual ~NonTerm(); + + const ConsistentPhrase &GetConsistentPhrase() const + { return *m_consistentPhrase; } + + int GetWidth(Moses::FactorDirection direction) const; + + virtual bool IsNonTerm() const + { return true; } + + std::string GetString() const + { return m_source + m_target; } + + virtual std::string Debug() const; + virtual void Output(std::ostream &out) const; + void Output(std::ostream &out, Moses::FactorDirection direction) const; + + const std::string &GetLabel(Moses::FactorDirection direction) const; + bool IsHiero(Moses::FactorDirection direction, const Parameter ¶ms) const; + bool IsHiero(const Parameter ¶ms) const; + +protected: + const ConsistentPhrase *m_consistentPhrase; + std::string m_source, m_target; +}; + diff --git a/contrib/other-builds/extract-mixed-syntax/Parameter.cpp b/contrib/other-builds/extract-mixed-syntax/Parameter.cpp new file mode 100644 index 000000000..f22116638 
--- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Parameter.cpp @@ -0,0 +1,41 @@ +/* + * Parameter.cpp + * + * Created on: 17 Feb 2014 + * Author: hieu + */ +#include "Parameter.h" + +Parameter::Parameter() +:maxSpan(10) +,maxNonTerm(2) +,maxHieroNonTerm(999) +,maxSymbolsTarget(999) +,maxSymbolsSource(5) +,minHoleSource(2) +,sentenceOffset(0) +,nonTermConsecSource(false) +,requireAlignedWord(true) +,fractionalCounting(true) +,gzOutput(false) + +,hieroNonTerm("[X]") +,sourceSyntax(false) +,targetSyntax(false) + +,mixedSyntaxType(0) +,multiLabel(0) +,nonTermConsecSourceMixed(true) +,hieroSourceLHS(false) +,maxSpanFreeNonTermSource(0) +,nieceTerminal(true) +,maxScope(UNDEFINED) + +,spanLength(false) +,nonTermContext(false) +{} + +Parameter::~Parameter() { + // TODO Auto-generated destructor stub +} + diff --git a/contrib/other-builds/extract-mixed-syntax/Parameter.h b/contrib/other-builds/extract-mixed-syntax/Parameter.h new file mode 100644 index 000000000..1da090c86 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Parameter.h @@ -0,0 +1,51 @@ +/* + * Parameter.h + * + * Created on: 17 Feb 2014 + * Author: hieu + */ +#pragma once + +#include +#include + +#define UNDEFINED std::numeric_limits::max() + +class Parameter +{ +public: + Parameter(); + virtual ~Parameter(); + + int maxSpan; + int maxNonTerm; + int maxHieroNonTerm; + int maxSymbolsTarget; + int maxSymbolsSource; + int minHoleSource; + + long sentenceOffset; + + bool nonTermConsecSource; + bool requireAlignedWord; + bool fractionalCounting; + bool gzOutput; + + std::string hieroNonTerm; + std::string gluePath; + + bool sourceSyntax, targetSyntax; + + int mixedSyntaxType, multiLabel; + bool nonTermConsecSourceMixed; + bool hieroSourceLHS; + int maxSpanFreeNonTermSource; + bool nieceTerminal; + int maxScope; + + // prperties + bool spanLength; + bool nonTermContext; + +}; + diff --git a/contrib/other-builds/extract-mixed-syntax/Phrase.cpp 
b/contrib/other-builds/extract-mixed-syntax/Phrase.cpp new file mode 100644 index 000000000..535e10d6b --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Phrase.cpp @@ -0,0 +1,14 @@ +#include +#include "Phrase.h" + +std::string Phrase::Debug() const +{ + std::stringstream out; + + for (size_t i = 0; i < size(); ++i) { + Word &word = *at(i); + out << word.Debug() << " "; + } + + return out.str(); +} diff --git a/contrib/other-builds/extract-mixed-syntax/Phrase.h b/contrib/other-builds/extract-mixed-syntax/Phrase.h new file mode 100644 index 000000000..13912cb95 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Phrase.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include "Word.h" + +// a vector of terminals +class Phrase : public std::vector +{ +public: + Phrase() + {} + + Phrase(size_t size) + :std::vector(size) + {} + + std::string Debug() const; + +}; diff --git a/contrib/other-builds/extract-mixed-syntax/Range.cpp b/contrib/other-builds/extract-mixed-syntax/Range.cpp deleted file mode 100644 index a98ac278b..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Range.cpp +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Range.cpp - * extract - * - * Created by Hieu Hoang on 22/02/2011. - * Copyright 2011 __MyCompanyName__. All rights reserved. 
- * - */ - -#include "Range.h" - -using namespace std; - -void Range::Merge(const Range &a, const Range &b) -{ - if (a.m_startPos == NOT_FOUND) - { // get the other regardless - m_startPos = b.m_startPos; - } - else if (b.m_startPos == NOT_FOUND) - { - m_startPos = a.m_startPos; - } - else - { - m_startPos = min(a.m_startPos, b.m_startPos); - } - - if (a.m_endPos == NOT_FOUND) - { // get the other regardless - m_endPos = b.m_endPos; - } - else if (b.m_endPos == NOT_FOUND) - { // do nothing - m_endPos = a.m_endPos; - } - else - { - m_endPos = max(a.m_endPos, b.m_endPos); - } - - -} - -int Range::Compare(const Range &other) const -{ - if (m_startPos < other.m_startPos) - return -1; - else if (m_startPos > other.m_startPos) - return +1; - else if (m_endPos < other.m_endPos) - return -1; - else if (m_endPos > other.m_endPos) - return +1; - - return 0; - -} - -bool Range::Overlap(const Range &other) const -{ - if ( other.m_endPos < m_startPos || other.m_startPos > m_endPos) - return false; - - return true; -} - -std::ostream& operator<<(std::ostream &out, const Range &range) -{ - out << "[" << range.m_startPos << "-" << range.m_endPos << "]"; - return out; -} - - diff --git a/contrib/other-builds/extract-mixed-syntax/Range.h b/contrib/other-builds/extract-mixed-syntax/Range.h deleted file mode 100644 index 05d0c97c9..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Range.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Range.h - * extract - * - * Created by Hieu Hoang on 22/02/2011. - * Copyright 2011 __MyCompanyName__. All rights reserved. 
- * - */ -#pragma once -#include -#include -#include - -#define NOT_FOUND std::numeric_limits::max() - -class Range -{ - friend std::ostream& operator<<(std::ostream&, const Range&); - - size_t m_startPos, m_endPos; -public: - - Range() - :m_startPos(NOT_FOUND) - ,m_endPos(NOT_FOUND) - {} - - Range(const Range ©) - :m_startPos(copy.m_startPos) - ,m_endPos(copy.m_endPos) - {} - - Range(size_t startPos, size_t endPos) - :m_startPos(startPos) - ,m_endPos(endPos) - {} - - size_t GetStartPos() const - { return m_startPos; } - size_t GetEndPos() const - { return m_endPos; } - size_t GetWidth() const - { return m_endPos - m_startPos + 1; } - - void SetStartPos(size_t startPos) - { m_startPos = startPos; } - void SetEndPos(size_t endPos) - { m_endPos = endPos; } - - void Merge(const Range &a, const Range &b); - - int Compare(const Range &other) const; - - bool Overlap(const Range &other) const; - - -}; diff --git a/contrib/other-builds/extract-mixed-syntax/Rule.cpp b/contrib/other-builds/extract-mixed-syntax/Rule.cpp index 7cc7d3a6f..c16d0f8c4 100644 --- a/contrib/other-builds/extract-mixed-syntax/Rule.cpp +++ b/contrib/other-builds/extract-mixed-syntax/Rule.cpp @@ -1,594 +1,540 @@ /* - * Rule.cpp - * extract - * - * Created by Hieu Hoang on 19/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
+ * Rule.cpp * + * Created on: 20 Feb 2014 + * Author: hieu */ -#include + #include +#include #include "Rule.h" -#include "Global.h" -#include "LatticeNode.h" -#include "Lattice.h" -#include "SentenceAlignment.h" -#include "Tunnel.h" -#include "TunnelCollection.h" -#include "RuleCollection.h" +#include "AlignedSentence.h" +#include "ConsistentPhrase.h" +#include "NonTerm.h" +#include "Parameter.h" using namespace std; -RuleElement::RuleElement(const RuleElement ©) -:m_latticeNode(copy.m_latticeNode) -,m_alignmentPos(copy.m_alignmentPos) +Rule::Rule(const NonTerm &lhsNonTerm, const AlignedSentence &alignedSentence) +:m_lhs(lhsNonTerm) +,m_alignedSentence(alignedSentence) +,m_isValid(true) +,m_canRecurse(true) { + CreateSource(); } - -Rule::Rule(const LatticeNode *latticeNode) -:m_lhs(NULL) +Rule::Rule(const Rule ©, const NonTerm &nonTerm) +:m_lhs(copy.m_lhs) +,m_alignedSentence(copy.m_alignedSentence) +,m_isValid(true) +,m_canRecurse(true) +,m_nonterms(copy.m_nonterms) { - RuleElement element(*latticeNode); - - m_coll.push_back(element); + m_nonterms.push_back(&nonTerm); + CreateSource(); + } -Rule::Rule(const Rule &prevRule, const LatticeNode *latticeNode) -:m_coll(prevRule.m_coll) -,m_lhs(NULL) -{ - RuleElement element(*latticeNode); - m_coll.push_back(element); +Rule::~Rule() { + // TODO Auto-generated destructor stub } -Rule::Rule(const Global &global, bool &isValid, const Rule ©, const LatticeNode *lhs, const SentenceAlignment &sentence) -:m_coll(copy.m_coll) -,m_source(copy.m_source) -,m_target(copy.m_target) -,m_lhs(lhs) -{ - CreateSymbols(global, isValid, sentence); +const ConsistentPhrase &Rule::GetConsistentPhrase() const +{ return m_lhs.GetConsistentPhrase(); } + +void Rule::CreateSource() +{ + const NonTerm *cp = NULL; + size_t nonTermInd = 0; + if (nonTermInd < m_nonterms.size()) { + cp = m_nonterms[nonTermInd]; + } + + for (int sourcePos = m_lhs.GetConsistentPhrase().corners[0]; + sourcePos <= m_lhs.GetConsistentPhrase().corners[1]; + ++sourcePos) { + 
+ const RuleSymbol *ruleSymbol; + if (cp && cp->GetConsistentPhrase().corners[0] <= sourcePos && sourcePos <= cp->GetConsistentPhrase().corners[1]) { + // replace words with non-term + ruleSymbol = cp; + sourcePos = cp->GetConsistentPhrase().corners[1]; + if (m_nonterms.size()) { + cp = m_nonterms[nonTermInd]; + } + + // move to next non-term + ++nonTermInd; + cp = (nonTermInd < m_nonterms.size()) ? m_nonterms[nonTermInd] : NULL; + } + else { + // terminal + ruleSymbol = m_alignedSentence.GetPhrase(Moses::Input)[sourcePos]; + } + + m_source.Add(ruleSymbol); + } } -Rule::~Rule() +int Rule::GetNextSourcePosForNonTerm() const { + if (m_nonterms.empty()) { + // no non-terms so far. Can start next non-term on left corner + return m_lhs.GetConsistentPhrase().corners[0]; + } + else { + // next non-term can start just left of previous + const ConsistentPhrase &cp = m_nonterms.back()->GetConsistentPhrase(); + int nextPos = cp.corners[1] + 1; + return nextPos; + } } -// helper for sort -struct CompareLatticeNodeTarget +std::string Rule::Debug() const { - bool operator() (const RuleElement *a, const RuleElement *b) - { - const Range &rangeA = a->GetLatticeNode().GetTunnel().GetRange(1) - ,&rangeB = b->GetLatticeNode().GetTunnel().GetRange(1); - return rangeA.GetEndPos() < rangeB.GetEndPos(); - } -}; + stringstream out; -void Rule::CreateSymbols(const Global &global, bool &isValid, const SentenceAlignment &sentence) -{ - vector nonTerms; - - // source - for (size_t ind = 0; ind < m_coll.size(); ++ind) - { - RuleElement &element = m_coll[ind]; - const LatticeNode &node = element.GetLatticeNode(); - if (node.IsTerminal()) - { - size_t sourcePos = node.GetSourceRange().GetStartPos(); - const string &word = sentence.source[sourcePos]; - Symbol symbol(word, sourcePos); - m_source.Add(symbol); - } - else - { // non-term - const string &sourceWord = node.GetSyntaxNode(0).GetLabel(); - const string &targetWord = node.GetSyntaxNode(1).GetLabel(); - Symbol symbol(sourceWord, targetWord 
- , node.GetTunnel().GetRange(0).GetStartPos(), node.GetTunnel().GetRange(0).GetEndPos() - , node.GetTunnel().GetRange(1).GetStartPos(), node.GetTunnel().GetRange(1).GetEndPos() - , node.GetSyntaxNode(0).IsSyntax(), node.GetSyntaxNode(1).IsSyntax()); - m_source.Add(symbol); + // source + for (size_t i = 0; i < m_source.GetSize(); ++i) { + const RuleSymbol &symbol = *m_source[i]; + out << symbol.Debug() << " "; + } - // store current pos within phrase - element.m_alignmentPos.first = ind; + // target + out << "||| "; + for (size_t i = 0; i < m_target.GetSize(); ++i) { + const RuleSymbol &symbol = *m_target[i]; + out << symbol.Debug() << " "; + } - // for target symbols - nonTerms.push_back(&element); - } - - } - - // target - isValid = true; - - const Range &lhsTargetRange = m_lhs->GetTunnel().GetRange(1); + out << "||| "; + Alignments::const_iterator iterAlign; + for (iterAlign = m_alignments.begin(); iterAlign != m_alignments.end(); ++iterAlign) { + const std::pair &alignPair = *iterAlign; + out << alignPair.first << "-" << alignPair.second << " "; + } - // check spans of target non-terms - if (nonTerms.size()) - { - // sort non-term rules elements by target range - std::sort(nonTerms.begin(), nonTerms.end(), CompareLatticeNodeTarget()); + // overall range + out << "||| LHS=" << m_lhs.Debug(); - const Range &first = nonTerms.front()->GetLatticeNode().GetTunnel().GetRange(1); - const Range &last = nonTerms.back()->GetLatticeNode().GetTunnel().GetRange(1); - - if (first.GetStartPos() < lhsTargetRange.GetStartPos() - || last.GetEndPos() > lhsTargetRange.GetEndPos()) - { - isValid = false; - } - } - - if (isValid) - { - size_t indNonTerm = 0; - RuleElement *currNonTermElement = indNonTerm < nonTerms.size() ? 
nonTerms[indNonTerm] : NULL; - for (size_t targetPos = lhsTargetRange.GetStartPos(); targetPos <= lhsTargetRange.GetEndPos(); ++targetPos) - { - if (currNonTermElement && targetPos == currNonTermElement->GetLatticeNode().GetTunnel().GetRange(1).GetStartPos()) - { // start of a non-term. print out non-terms & skip to the end - - const LatticeNode &node = currNonTermElement->GetLatticeNode(); - - const string &sourceWord = node.GetSyntaxNode(0).GetLabel(); - const string &targetWord = node.GetSyntaxNode(1).GetLabel(); - Symbol symbol(sourceWord, targetWord - , node.GetTunnel().GetRange(0).GetStartPos(), node.GetTunnel().GetRange(0).GetEndPos() - , node.GetTunnel().GetRange(1).GetStartPos(), node.GetTunnel().GetRange(1).GetEndPos() - , node.GetSyntaxNode(0).IsSyntax(), node.GetSyntaxNode(1).IsSyntax()); - m_target.Add(symbol); - - // store current pos within phrase - currNonTermElement->m_alignmentPos.second = m_target.GetSize() - 1; - - assert(currNonTermElement->m_alignmentPos.first != NOT_FOUND); - - targetPos = node.GetTunnel().GetRange(1).GetEndPos(); - indNonTerm++; - currNonTermElement = indNonTerm < nonTerms.size() ? 
nonTerms[indNonTerm] : NULL; - } - else - { // term - const string &word = sentence.target[targetPos]; - - Symbol symbol(word, targetPos); - m_target.Add(symbol); - - } - } - - assert(indNonTerm == nonTerms.size()); - - if (m_target.GetSize() > global.maxSymbols) { - isValid = false; - //cerr << "m_source=" << m_source.GetSize() << ":" << m_source << endl; - //cerr << "m_target=" << m_target.GetSize() << ":" << m_target << endl; - } - } + return out.str(); } -bool Rule::MoreDefaultNonTermThanTerm() const +void Rule::Output(std::ostream &out, bool forward, const Parameter ¶ms) const { - size_t numTerm = 0, numDefaultNonTerm = 0; - - CollType::const_iterator iter; - for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) - { - const RuleElement &element = *iter; - const LatticeNode &node = element.GetLatticeNode(); - if (node.IsTerminal()) - { - ++numTerm; - } - else if (!node.IsSyntax()) - { - ++numDefaultNonTerm; - } - } - - bool ret = numDefaultNonTerm > numTerm; - return ret; + if (forward) { + // source + m_source.Output(out); + m_lhs.Output(out, Moses::Input); + + out << " ||| "; + + // target + m_target.Output(out); + m_lhs.Output(out, Moses::Output); + } + else { + // target + m_target.Output(out); + m_lhs.Output(out, Moses::Output); + + out << " ||| "; + + // source + m_source.Output(out); + m_lhs.Output(out, Moses::Input); + } + + out << " ||| "; + + // alignment + Alignments::const_iterator iterAlign; + for (iterAlign = m_alignments.begin(); iterAlign != m_alignments.end(); ++iterAlign) { + const std::pair &alignPair = *iterAlign; + + if (forward) { + out << alignPair.first << "-" << alignPair.second << " "; + } + else { + out << alignPair.second << "-" << alignPair.first << " "; + } + } + + out << "||| "; + + // count + out << m_count; + + out << " ||| "; + + // properties + + // span length + if (forward && params.spanLength && m_nonterms.size()) { + out << "{{SpanLength "; + + for (size_t i = 0; i < m_nonterms.size(); ++i) { + const NonTerm &nonTerm = 
*m_nonterms[i]; + const ConsistentPhrase &cp = nonTerm.GetConsistentPhrase(); + out << i << "," << cp.GetWidth(Moses::Input) << "," << cp.GetWidth(Moses::Output) << " "; + } + out << "}} "; + } + + // non-term context + if (forward && params.nonTermContext && m_nonterms.size()) { + out << "{{NonTermContext "; + + for (size_t i = 0; i < m_nonterms.size(); ++i) { + const NonTerm &nonTerm = *m_nonterms[i]; + const ConsistentPhrase &cp = nonTerm.GetConsistentPhrase(); + NonTermContext(i, cp, out); + } + out << "}} "; + } } -bool Rule::SourceHasEdgeDefaultNonTerm() const +void Rule::NonTermContext(size_t ntInd, const ConsistentPhrase &cp, std::ostream &out) const { - assert(m_coll.size()); - const LatticeNode &first = m_coll.front().GetLatticeNode(); - const LatticeNode &last = m_coll.back().GetLatticeNode(); + int startPos = cp.corners[0]; + int endPos = cp.corners[1]; + + const Phrase &source = m_alignedSentence.GetPhrase(Moses::Input); + + if (startPos == 0) { + out << " "; + } + else { + out << source[startPos - 1]->GetString() << " "; + } + + out << source[startPos]->GetString() << " "; + out << source[endPos]->GetString() << " "; + + if (endPos == source.size() - 1) { + out << " "; + } + else { + out << source[endPos + 1]->GetString() << " "; + } + - // 1st - if (!first.IsTerminal() && !first.IsSyntax()) - { - return true; - } - if (!last.IsTerminal() && !last.IsSyntax()) - { - return true; - } - - return false; } -bool Rule::IsValid(const Global &global, const TunnelCollection &tunnelColl) const +void Rule::Prevalidate(const Parameter ¶ms) { - if (m_coll.size() == 1 && !m_coll[0].GetLatticeNode().IsTerminal()) // can't be only 1 terminal - { - return false; - } + const ConsistentPhrase &cp = m_lhs.GetConsistentPhrase(); - if (MoreDefaultNonTermThanTerm()) - { // must have at least as many terms as non-syntax non-terms - return false; - } + // check number of source symbols in rule + if (m_source.GetSize() > params.maxSymbolsSource) { + m_isValid = false; + } - if 
(!global.allowDefaultNonTermEdge && SourceHasEdgeDefaultNonTerm()) - { - return false; - } - - if (GetNumSymbols() > global.maxSymbols) - { - return false; - } - - if (AdjacentDefaultNonTerms()) - { - return false; - } - - if (!IsHole(tunnelColl)) - { - return false; - } + // check that last non-term added isn't too small + if (m_nonterms.size()) { + const NonTerm &lastNonTerm = *m_nonterms.back(); + const ConsistentPhrase &cp = lastNonTerm.GetConsistentPhrase(); - if (NonTermOverlap()) - { - return false; - } - - /* - std::pair spanS = GetSpan(0) - ,spanT= GetSpan(1); + int sourceWidth = cp.corners[1] - cp.corners[0] + 1; + if (sourceWidth < params.minHoleSource) { + m_isValid = false; + m_canRecurse = false; + return; + } + } - if (tunnelColl.NumUnalignedWord(0, spanS.first, spanS.second) >= global.maxUnaligned) - return false; - if (tunnelColl.NumUnalignedWord(1, spanT.first, spanT.second) >= global.maxUnaligned) - return false; - */ - - return true; + // check number of non-terms + int numNonTerms = 0; + int numHieroNonTerms = 0; + for (size_t i = 0; i < m_source.GetSize(); ++i) { + const RuleSymbol *arc = m_source[i]; + if (arc->IsNonTerm()) { + ++numNonTerms; + const NonTerm &nonTerm = *static_cast(arc); + bool isHiero = nonTerm.IsHiero(params); + if (isHiero) { + ++numHieroNonTerms; + } + } + } + + if (numNonTerms >= params.maxNonTerm) { + m_canRecurse = false; + if (numNonTerms > params.maxNonTerm) { + m_isValid = false; + return; + } + } + + if (numHieroNonTerms >= params.maxHieroNonTerm) { + m_canRecurse = false; + if (numHieroNonTerms > params.maxHieroNonTerm) { + m_isValid = false; + return; + } + } + + // check if 2 consecutive non-terms in source + if (!params.nonTermConsecSource && m_nonterms.size() >= 2) { + const NonTerm &lastNonTerm = *m_nonterms.back(); + const NonTerm &secondLastNonTerm = *m_nonterms[m_nonterms.size() - 2]; + if (secondLastNonTerm.GetConsistentPhrase().corners[1] + 1 == + lastNonTerm.GetConsistentPhrase().corners[0]) { + if 
(params.mixedSyntaxType == 0) { + // ordinary hiero or syntax model + m_isValid = false; + m_canRecurse = false; + return; + } + else { + // Hieu's mixed syntax + if (lastNonTerm.IsHiero(Moses::Input, params) + && secondLastNonTerm.IsHiero(Moses::Input, params)) { + m_isValid = false; + m_canRecurse = false; + return; + } + } + + } + } + + //check to see if it overlaps with any other non-terms + if (m_nonterms.size() >= 2) { + const NonTerm &lastNonTerm = *m_nonterms.back(); + + for (size_t i = 0; i < m_nonterms.size() - 1; ++i) { + const NonTerm &otherNonTerm = *m_nonterms[i]; + bool overlap = lastNonTerm.GetConsistentPhrase().TargetOverlap(otherNonTerm.GetConsistentPhrase()); + + if (overlap) { + m_isValid = false; + m_canRecurse = false; + return; + } + } + } + + // check that at least 1 word is aligned + if (params.requireAlignedWord) { + bool ok = false; + for (size_t i = 0; i < m_source.GetSize(); ++i) { + const RuleSymbol &symbol = *m_source[i]; + if (!symbol.IsNonTerm()) { + const Word &word = static_cast(symbol); + if (word.GetAlignment().size()) { + ok = true; + break; + } + } + } + + if (!ok) { + m_isValid = false; + m_canRecurse = false; + return; + } + } + + if (params.maxSpanFreeNonTermSource) { + const NonTerm *front = dynamic_cast(m_source[0]); + if (front) { + int width = front->GetWidth(Moses::Input); + if (width > params.maxSpanFreeNonTermSource) { + m_isValid = false; + m_canRecurse = false; + return; + } + } + + const NonTerm *back = dynamic_cast(m_source.Back()); + if (back) { + int width = back->GetWidth(Moses::Input); + if (width > params.maxSpanFreeNonTermSource) { + m_isValid = false; + m_canRecurse = false; + return; + } + } + } + + if (!params.nieceTerminal) { + // collect terminal in a rule + std::set terms; + for (size_t i = 0; i < m_source.GetSize(); ++i) { + const Word *word = dynamic_cast(m_source[i]); + if (word) { + terms.insert(word); + } + } + + // look in non-terms + for (size_t i = 0; i < m_source.GetSize(); ++i) { + const 
NonTerm *nonTerm = dynamic_cast(m_source[i]); + if (nonTerm) { + const ConsistentPhrase &cp = nonTerm->GetConsistentPhrase(); + bool containTerm = ContainTerm(cp, terms); + + if (containTerm) { + //cerr << "ruleSource=" << *ruleSource << " "; + //cerr << "ntRange=" << ntRange << endl; + + // non-term contains 1 of the terms in the rule. + m_isValid = false; + m_canRecurse = false; + return; + } + } + } + } + + if (params.maxScope != UNDEFINED) { + int scope = CalcScope(); + if (scope > params.maxScope) { + m_isValid = false; + m_canRecurse = false; + return; + } + } } -bool Rule::NonTermOverlap() const +int Rule::CalcScope() const { - vector ranges; - - CollType::const_iterator iter; - for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) - { - const RuleElement &element = *iter; - if (!element.GetLatticeNode().IsTerminal()) - { - const Range &range = element.GetLatticeNode().GetTunnel().GetRange(1); - ranges.push_back(range); - } - } - - vector::const_iterator outerIter; - for (outerIter = ranges.begin(); outerIter != ranges.end(); ++outerIter) - { - const Range &outer = *outerIter; - vector::const_iterator innerIter; - for (innerIter = outerIter + 1; innerIter != ranges.end(); ++innerIter) - { - const Range &inner = *innerIter; - if (outer.Overlap(inner)) - return true; - } - } - - return false; + int scope = 0; + if (m_source.GetSize() > 1) { + const RuleSymbol &front = *m_source.Front(); + if (front.IsNonTerm()) { + ++scope; + } + + const RuleSymbol &back = *m_source.Back(); + if (back.IsNonTerm()) { + ++scope; + } + } + return scope; } -Range Rule::GetSourceRange() const +template +bool Contains(const T *sought, const set &coll) { - assert(m_coll.size()); - const Range &first = m_coll.front().GetLatticeNode().GetSourceRange(); - const Range &last = m_coll.back().GetLatticeNode().GetSourceRange(); - - Range ret(first.GetStartPos(), last.GetEndPos()); - return ret; -} - - -bool Rule::IsHole(const TunnelCollection &tunnelColl) const -{ - const Range &spanS = 
GetSourceRange(); - const TunnelList &tunnels = tunnelColl.GetTunnels(spanS.GetStartPos(), spanS.GetEndPos()); - - bool ret = tunnels.size() > 0; - return ret; -} - - -bool Rule::CanRecurse(const Global &global, const TunnelCollection &tunnelColl) const -{ - if (GetNumSymbols() >= global.maxSymbols) - return false; - if (AdjacentDefaultNonTerms()) - return false; - if (MaxNonTerm(global)) - return false; - if (NonTermOverlap()) - { - return false; - } - - const Range spanS = GetSourceRange(); - - if (tunnelColl.NumUnalignedWord(0, spanS.GetStartPos(), spanS.GetEndPos()) >= global.maxUnaligned) - return false; -// if (tunnelColl.NumUnalignedWord(1, spanT.first, spanT.second) >= global.maxUnaligned) -// return false; - - - return true; -} - -bool Rule::MaxNonTerm(const Global &global) const -{ - //cerr << *this << endl; - size_t numNonTerm = 0, numNonTermDefault = 0; - - CollType::const_iterator iter; - for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) - { - const LatticeNode *node = &(*iter).GetLatticeNode(); - if (!node->IsTerminal() ) - { - numNonTerm++; - if (!node->IsSyntax()) - { - numNonTermDefault++; - } - if (numNonTerm >= global.maxNonTerm || numNonTermDefault >= global.maxNonTermDefault) - return true; - } - } - - return false; -} - - -bool Rule::AdjacentDefaultNonTerms() const -{ - assert(m_coll.size() > 0); - - const LatticeNode *prevNode = &m_coll.front().GetLatticeNode(); - CollType::const_iterator iter; - for (iter = m_coll.begin() + 1; iter != m_coll.end(); ++iter) - { - const LatticeNode *node = &(*iter).GetLatticeNode(); - if (!prevNode->IsTerminal() && !node->IsTerminal() && !prevNode->IsSyntax() && !node->IsSyntax() ) - { + std::set::const_iterator iter; + for (iter = coll.begin(); iter != coll.end(); ++iter) { + const Word *found = *iter; + if (sought->CompareString(*found) == 0) { return true; } - prevNode = node; } - return false; } - - -size_t Rule::GetNumSymbols() const +bool Rule::ContainTerm(const ConsistentPhrase &cp, const 
std::set &terms) const { - size_t ret = m_coll.size(); - return ret; -} + const Phrase &sourceSentence = m_alignedSentence.GetPhrase(Moses::Input); -void Rule::CreateRules(RuleCollection &rules - , const Lattice &lattice - , const SentenceAlignment &sentence - , const Global &global) -{ - assert(m_coll.size() > 0); - const LatticeNode *latticeNode = &m_coll.back().GetLatticeNode(); - size_t endPos = latticeNode->GetSourceRange().GetEndPos() + 1; - - const Stack &stack = lattice.GetStack(endPos); - - Stack::const_iterator iter; - for (iter = stack.begin(); iter != stack.end(); ++iter) - { - const LatticeNode *newLatticeNode = *iter; - Rule *newRule = new Rule(*this, newLatticeNode); - //cerr << *newRule << endl; - - if (newRule->CanRecurse(global, sentence.GetTunnelCollection())) - { // may or maynot be valid, but can continue to build on this rule - newRule->CreateRules(rules, lattice, sentence, global); + for (int pos = cp.corners[0]; pos <= cp.corners[1]; ++pos) { + const Word *soughtWord = sourceSentence[pos]; + + // find same word in set + if (Contains(soughtWord, terms)) { + return true; } - - if (newRule->IsValid(global, sentence.GetTunnelCollection())) - { // add to rule collection - rules.Add(global, newRule, sentence); - } - else - { - delete newRule; + } + return false; +} + +bool CompareTargetNonTerms(const NonTerm *a, const NonTerm *b) +{ + // compare just start target pos + return a->GetConsistentPhrase().corners[2] < b->GetConsistentPhrase().corners[2]; +} + +void Rule::CreateTarget(const Parameter ¶ms) +{ + if (!m_isValid) { + return; + } + + vector targetNonTerm(m_nonterms); + std::sort(targetNonTerm.begin(), targetNonTerm.end(), CompareTargetNonTerms); + + const NonTerm *cp = NULL; + size_t nonTermInd = 0; + if (nonTermInd < targetNonTerm.size()) { + cp = targetNonTerm[nonTermInd]; + } + + for (int targetPos = m_lhs.GetConsistentPhrase().corners[2]; + targetPos <= m_lhs.GetConsistentPhrase().corners[3]; + ++targetPos) { + + const RuleSymbol 
*ruleSymbol; + if (cp && cp->GetConsistentPhrase().corners[2] <= targetPos && targetPos <= cp->GetConsistentPhrase().corners[3]) { + // replace words with non-term + ruleSymbol = cp; + targetPos = cp->GetConsistentPhrase().corners[3]; + if (targetNonTerm.size()) { + cp = targetNonTerm[nonTermInd]; + } + + // move to next non-term + ++nonTermInd; + cp = (nonTermInd < targetNonTerm.size()) ? targetNonTerm[nonTermInd] : NULL; + } + else { + // terminal + ruleSymbol = m_alignedSentence.GetPhrase(Moses::Output)[targetPos]; + } + + m_target.Add(ruleSymbol); + } + + CreateAlignments(); +} + + +void Rule::CreateAlignments() +{ + int sourceStart = GetConsistentPhrase().corners[0]; + int targetStart = GetConsistentPhrase().corners[2]; + + for (size_t sourcePos = 0; sourcePos < m_source.GetSize(); ++sourcePos) { + const RuleSymbol *symbol = m_source[sourcePos]; + if (!symbol->IsNonTerm()) { + // terminals + const Word &sourceWord = static_cast(*symbol); + const std::set &targetWords = sourceWord.GetAlignment(); + CreateAlignments(sourcePos, targetWords); + } + else { + // non-terms. 
same object in both source & target + CreateAlignments(sourcePos, symbol); + } + } +} + +void Rule::CreateAlignments(int sourcePos, const std::set &targetWords) +{ + std::set::const_iterator iterTarget; + for (iterTarget = targetWords.begin(); iterTarget != targetWords.end(); ++iterTarget) { + const Word *targetWord = *iterTarget; + CreateAlignments(sourcePos, targetWord); + } +} + +void Rule::CreateAlignments(int sourcePos, const RuleSymbol *targetSought) +{ + // should be in target phrase + for (size_t targetPos = 0; targetPos < m_target.GetSize(); ++targetPos) { + const RuleSymbol *foundSymbol = m_target[targetPos]; + if (targetSought == foundSymbol) { + pair alignPoint(sourcePos, targetPos); + m_alignments.insert(alignPoint); + return; } - - } -} - -bool Rule::operator<(const Rule &compare) const -{ - /* - if (g_debug) - { - cerr << *this << endl << compare; - cerr << endl; - } - */ - - bool ret = Compare(compare) < 0; - - /* - if (g_debug) - { - cerr << *this << endl << compare << endl << ret << endl << endl; - } - */ - - return ret; -} - -int Rule::Compare(const Rule &compare) const -{ - //cerr << *this << endl << compare << endl; - assert(m_coll.size() > 0); - assert(m_source.GetSize() > 0); - assert(m_target.GetSize() > 0); - - int ret = 0; - - // compare each fragment - ret = m_source.Compare(compare.m_source); - if (ret != 0) - { - return ret; } - ret = m_target.Compare(compare.m_target); - if (ret != 0) - { - return ret; - } - - // compare lhs - const string &thisSourceLabel = m_lhs->GetSyntaxNode(0).GetLabel(); - const string &otherSourceLabel = compare.m_lhs->GetSyntaxNode(0).GetLabel(); - if (thisSourceLabel != otherSourceLabel) - { - ret = (thisSourceLabel < otherSourceLabel) ? -1 : +1; - return ret; - } - - const string &thisTargetLabel = m_lhs->GetSyntaxNode(1).GetLabel(); - const string &otherTargetLabel = compare.m_lhs->GetSyntaxNode(1).GetLabel(); - if (thisTargetLabel != otherTargetLabel) - { - ret = (thisTargetLabel < otherTargetLabel) ? 
-1 : +1; - return ret; - } - - assert(ret == 0); - return ret; + throw "not found"; } - -const LatticeNode &Rule::GetLatticeNode(size_t ind) const -{ - assert(ind < m_coll.size()); - return m_coll[ind].GetLatticeNode(); -} - -void Rule::DebugOutput() const -{ - Output(cerr); -} - -void Rule::Output(std::ostream &out) const -{ - - stringstream strmeS, strmeT; - - std::vector::const_iterator iterSymbol; - for (iterSymbol = m_source.begin(); iterSymbol != m_source.end(); ++iterSymbol) - { - const Symbol &symbol = *iterSymbol; - strmeS << symbol << " "; - } - - for (iterSymbol = m_target.begin(); iterSymbol != m_target.end(); ++iterSymbol) - { - const Symbol &symbol = *iterSymbol; - strmeT << symbol << " "; - } - - // lhs - if (m_lhs) - { - strmeS << m_lhs->GetSyntaxNode(0).GetLabel(); - strmeT << m_lhs->GetSyntaxNode(1).GetLabel(); - } - - out << strmeS.str() << " ||| " << strmeT.str() << " ||| "; - - // alignment - Rule::CollType::const_iterator iter; - for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) - { - const RuleElement &element = *iter; - const LatticeNode &node = element.GetLatticeNode(); - bool isTerminal = node.IsTerminal(); - - if (!isTerminal) - { - out << element.m_alignmentPos.first << "-" << element.m_alignmentPos.second << " "; - } - } - - out << "||| 1"; - -} - -void Rule::OutputInv(std::ostream &out) const -{ - stringstream strmeS, strmeT; - - std::vector::const_iterator iterSymbol; - for (iterSymbol = m_source.begin(); iterSymbol != m_source.end(); ++iterSymbol) - { - const Symbol &symbol = *iterSymbol; - strmeS << symbol << " "; - } - - for (iterSymbol = m_target.begin(); iterSymbol != m_target.end(); ++iterSymbol) - { - const Symbol &symbol = *iterSymbol; - strmeT << symbol << " "; - } - - // lhs - if (m_lhs) - { - strmeS << m_lhs->GetSyntaxNode(0).GetLabel(); - strmeT << m_lhs->GetSyntaxNode(1).GetLabel(); - } - - out << strmeT.str() << " ||| " << strmeS.str() << " ||| "; - - // alignment - Rule::CollType::const_iterator iter; - for 
(iter = m_coll.begin(); iter != m_coll.end(); ++iter) - { - const RuleElement &element = *iter; - const LatticeNode &node = element.GetLatticeNode(); - bool isTerminal = node.IsTerminal(); - - if (!isTerminal) - { - out << element.m_alignmentPos.second << "-" << element.m_alignmentPos.first << " "; - } - } - - out << "||| 1"; - -} - - diff --git a/contrib/other-builds/extract-mixed-syntax/Rule.h b/contrib/other-builds/extract-mixed-syntax/Rule.h index 3574094fe..e97dc6d7f 100644 --- a/contrib/other-builds/extract-mixed-syntax/Rule.h +++ b/contrib/other-builds/extract-mixed-syntax/Rule.h @@ -1,96 +1,87 @@ -#pragma once /* - * Rule.h - * extract - * - * Created by Hieu Hoang on 19/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. + * Rule.h * + * Created on: 20 Feb 2014 + * Author: hieu */ +#pragma once #include -#include -#include "LatticeNode.h" -#include "SymbolSequence.h" -#include "Global.h" +#include "Phrase.h" +#include "RulePhrase.h" +#include "moses/TypeDef.h" -class Lattice; -class SentenceAlignment; -class Global; -class RuleCollection; -class SyntaxNode; -class TunnelCollection; -class Range; +class ConsistentPhrase; +class AlignedSentence; +class NonTerm; +class Parameter; -class RuleElement -{ -protected: - const LatticeNode *m_latticeNode; + +class Rule { public: - std::pair m_alignmentPos; - - RuleElement(const RuleElement ©); - RuleElement(const LatticeNode &latticeNode) - :m_latticeNode(&latticeNode) - ,m_alignmentPos(NOT_FOUND, NOT_FOUND) - {} + typedef std::set > Alignments; - const LatticeNode &GetLatticeNode() const - { return *m_latticeNode; } + Rule(const Rule ©); // do not implement -}; + // original rule with no non-term + Rule(const NonTerm &lhsNonTerm, const AlignedSentence &alignedSentence); -class Rule -{ -protected: - typedef std::vector CollType; - CollType m_coll; - - const LatticeNode *m_lhs; - SymbolSequence m_source, m_target; - - bool IsHole(const TunnelCollection &tunnelColl) const; - bool NonTermOverlap() 
const; - - const LatticeNode &GetLatticeNode(size_t ind) const; - void CreateSymbols(const Global &global, bool &isValid, const SentenceAlignment &sentence); - -public: - // init - Rule(const LatticeNode *latticeNode); - - // create new rule by appending node to prev rule - Rule(const Rule &prevRule, const LatticeNode *latticeNode); - - // create copy with lhs - Rule(const Global &global, bool &isValid, const Rule ©, const LatticeNode *lhs, const SentenceAlignment &sentence); - - // can continue to add to this rule - bool CanRecurse(const Global &global, const TunnelCollection &tunnelColl) const; + // extend a rule, adding 1 new non-term + Rule(const Rule ©, const NonTerm &nonTerm); virtual ~Rule(); - // can add this to the set of rules - bool IsValid(const Global &global, const TunnelCollection &tunnelColl) const; + bool IsValid() const + { return m_isValid; } - size_t GetNumSymbols() const; - bool AdjacentDefaultNonTerms() const; - bool MaxNonTerm(const Global &global) const; - bool MoreDefaultNonTermThanTerm() const; - bool SourceHasEdgeDefaultNonTerm() const; + bool CanRecurse() const + { return m_canRecurse; } - void CreateRules(RuleCollection &rules - , const Lattice &lattice - , const SentenceAlignment &sentence - , const Global &global); - - int Compare(const Rule &compare) const; - bool operator<(const Rule &compare) const; - - Range GetSourceRange() const; - - DEBUG_OUTPUT(); + const NonTerm &GetLHS() const + { return m_lhs; } - void Output(std::ostream &out) const; - void OutputInv(std::ostream &out) const; + const ConsistentPhrase &GetConsistentPhrase() const; + + int GetNextSourcePosForNonTerm() const; + + void SetCount(float count) + { m_count = count; } + float GetCount() const + { return m_count; } + + const Alignments &GetAlignments() const + { return m_alignments; } + + std::string Debug() const; + void Output(std::ostream &out, bool forward, const Parameter ¶ms) const; + + void Prevalidate(const Parameter ¶ms); + void CreateTarget(const Parameter 
¶ms); + + const RulePhrase &GetPhrase(Moses::FactorDirection direction) const + { return (direction == Moses::Input) ? m_source : m_target; } + +protected: + const NonTerm &m_lhs; + const AlignedSentence &m_alignedSentence; + RulePhrase m_source, m_target; + float m_count; + + Alignments m_alignments; + + // in source order + std::vector m_nonterms; + + bool m_isValid, m_canRecurse; + + void CreateSource(); + void CreateAlignments(); + void CreateAlignments(int sourcePos, const std::set &targetWords); + void CreateAlignments(int sourcePos, const RuleSymbol *targetSought); + + bool ContainTerm(const ConsistentPhrase &cp, const std::set &terms) const; + int CalcScope() const; // not yet correctly calculated + + void NonTermContext(size_t ntInd, const ConsistentPhrase &cp, std::ostream &out) const; }; + diff --git a/contrib/other-builds/extract-mixed-syntax/RuleCollection.cpp b/contrib/other-builds/extract-mixed-syntax/RuleCollection.cpp deleted file mode 100644 index 8389a70cf..000000000 --- a/contrib/other-builds/extract-mixed-syntax/RuleCollection.cpp +++ /dev/null @@ -1,102 +0,0 @@ -/* - * RuleCollection.cpp - * extract - * - * Created by Hieu Hoang on 19/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ -#include "RuleCollection.h" -#include "Rule.h" -#include "SentenceAlignment.h" -#include "tables-core.h" -#include "Lattice.h" -#include "SyntaxTree.h" - -using namespace std; - -RuleCollection::~RuleCollection() -{ - RemoveAllInColl(m_coll); -} - -void RuleCollection::Add(const Global &global, Rule *rule, const SentenceAlignment &sentence) -{ - Range spanS = rule->GetSourceRange(); - - // cartesian product of lhs - Stack nontermNodes = sentence.GetLattice().GetNonTermNode(spanS); - Stack::const_iterator iterStack; - for (iterStack = nontermNodes.begin(); iterStack != nontermNodes.end(); ++iterStack) - { - const LatticeNode &node = **iterStack; - assert(!node.IsTerminal()); - - bool isValid; - // create rules with LHS - //cerr << "old:" << *rule << endl; - Rule *newRule = new Rule(global, isValid, *rule, &node, sentence); - - if (!isValid) - { // lhs doesn't match non-term spans - delete newRule; - continue; - } - - /* - stringstream s; - s << *newRule; - if (s.str().find("Wiederaufnahme der [X] ||| resumption of the [X] ||| ||| 1") == 0) - { - cerr << "READY:" << *newRule << endl; - g_debug = true; - } - else { - g_debug = false; - } - */ - - typedef set::iterator Iterator; - pair ret = m_coll.insert(newRule); - - if (ret.second) - { - //cerr << "ACCEPTED:" << *newRule << endl; - //cerr << ""; - } - else - { - //cerr << "REJECTED:" << *newRule << endl; - delete newRule; - } - - } - - delete rule; - -} - -void RuleCollection::Output(std::ostream &out) const -{ - RuleCollection::CollType::const_iterator iter; - for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) - { - const Rule &rule = **iter; - rule.Output(out); - out << endl; - } -} - -void RuleCollection::OutputInv(std::ostream &out) const -{ - RuleCollection::CollType::const_iterator iter; - for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) - { - const Rule &rule = **iter; - rule.OutputInv(out); - out << endl; - } -} - - - diff --git 
a/contrib/other-builds/extract-mixed-syntax/RuleCollection.h b/contrib/other-builds/extract-mixed-syntax/RuleCollection.h deleted file mode 100644 index 27d5d794a..000000000 --- a/contrib/other-builds/extract-mixed-syntax/RuleCollection.h +++ /dev/null @@ -1,55 +0,0 @@ -#pragma once -/* - * RuleCollection.h - * extract - * - * Created by Hieu Hoang on 19/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. - * - */ -#include -#include -#include "Rule.h" - -class SentenceAlignment; - -// helper for sort. Don't compare default non-terminals -struct CompareRule -{ - bool operator() (const Rule *a, const Rule *b) - { - /* - if (g_debug) - { - std::cerr << std::endl << (*a) << std::endl << (*b) << " "; - } - */ - bool ret = (*a) < (*b); - /* - if (g_debug) - { - std::cerr << ret << std::endl; - } - */ - return ret; - } -}; - - -class RuleCollection -{ -protected: - typedef std::set CollType; - CollType m_coll; - -public: - ~RuleCollection(); - void Add(const Global &global, Rule *rule, const SentenceAlignment &sentence); - size_t GetSize() const - { return m_coll.size(); } - - void Output(std::ostream &out) const; - void OutputInv(std::ostream &out) const; - -}; - diff --git a/contrib/other-builds/extract-mixed-syntax/RulePhrase.cpp b/contrib/other-builds/extract-mixed-syntax/RulePhrase.cpp new file mode 100644 index 000000000..5c629168b --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/RulePhrase.cpp @@ -0,0 +1,50 @@ +/* + * RulePhrase.cpp + * + * Created on: 26 Feb 2014 + * Author: hieu + */ + +#include +#include "RulePhrase.h" +#include "RuleSymbol.h" + +using namespace std; + +extern bool g_debug; + +int RulePhrase::Compare(const RulePhrase &other) const +{ + if (GetSize() != other.GetSize()) { + return GetSize() < other.GetSize() ? 
-1 : +1; + } + + for (size_t i = 0; i < m_coll.size(); ++i) { + const RuleSymbol &symbol = *m_coll[i]; + const RuleSymbol &otherSymbol = *other.m_coll[i]; + int compare = symbol.Compare(otherSymbol); + + if (compare) { + return compare; + } + } + + return 0; +} + +void RulePhrase::Output(std::ostream &out) const +{ + for (size_t i = 0; i < m_coll.size(); ++i) { + const RuleSymbol &symbol = *m_coll[i]; + symbol.Output(out); + out << " "; + } +} + +std::string RulePhrase::Debug() const +{ + std::stringstream out; + Output(out); + return out.str(); +} + diff --git a/contrib/other-builds/extract-mixed-syntax/RulePhrase.h b/contrib/other-builds/extract-mixed-syntax/RulePhrase.h new file mode 100644 index 000000000..412169b74 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/RulePhrase.h @@ -0,0 +1,49 @@ +/* + * RulePhrase.h + * + * Created on: 26 Feb 2014 + * Author: hieu + */ + +#ifndef RULEPHRASE_H_ +#define RULEPHRASE_H_ + +#include +#include +#include + +class RuleSymbol; + +// a phrase of terms and non-terms for 1 side of a rule +class RulePhrase +{ +public: + typedef std::vector Coll; + Coll m_coll; + + size_t GetSize() const + { return m_coll.size(); } + + void Add(const RuleSymbol *symbol) + { + m_coll.push_back(symbol); + } + + const RuleSymbol* operator[](size_t index) const { + return m_coll[index]; + } + + const RuleSymbol* Front() const { + return m_coll.front(); + } + const RuleSymbol* Back() const { + return m_coll.back(); + } + + int Compare(const RulePhrase &other) const; + + void Output(std::ostream &out) const; + std::string Debug() const; +}; + +#endif /* RULEPHRASE_H_ */ diff --git a/contrib/other-builds/extract-mixed-syntax/RuleSymbol.cpp b/contrib/other-builds/extract-mixed-syntax/RuleSymbol.cpp new file mode 100644 index 000000000..933ffc9c2 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/RuleSymbol.cpp @@ -0,0 +1,36 @@ +/* + * RuleSymbol.cpp + * + * Created on: 21 Feb 2014 + * Author: hieu + */ + +#include 
"RuleSymbol.h" + +using namespace std; + +RuleSymbol::RuleSymbol() { + // TODO Auto-generated constructor stub + +} + +RuleSymbol::~RuleSymbol() { + // TODO Auto-generated destructor stub +} + +int RuleSymbol::Compare(const RuleSymbol &other) const +{ + if (IsNonTerm() != other.IsNonTerm()) { + return IsNonTerm() ? -1 : +1; + } + + string str = GetString(); + string otherStr = other.GetString(); + + if (str == otherStr) { + return 0; + } + else { + return (str < otherStr) ? -1 : +1; + } +} diff --git a/contrib/other-builds/extract-mixed-syntax/RuleSymbol.h b/contrib/other-builds/extract-mixed-syntax/RuleSymbol.h new file mode 100644 index 000000000..c292fcc0d --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/RuleSymbol.h @@ -0,0 +1,31 @@ +/* + * RuleSymbol.h + * + * Created on: 21 Feb 2014 + * Author: hieu + */ + +#ifndef RULESYMBOL_H_ +#define RULESYMBOL_H_ + +#include +#include + +// base class - terminal or non-term +class RuleSymbol { +public: + RuleSymbol(); + virtual ~RuleSymbol(); + + virtual bool IsNonTerm() const = 0; + + virtual std::string Debug() const = 0; + virtual void Output(std::ostream &out) const = 0; + + virtual std::string GetString() const = 0; + + int Compare(const RuleSymbol &other) const; + +}; + +#endif /* RULESYMBOL_H_ */ diff --git a/contrib/other-builds/extract-mixed-syntax/Rules.cpp b/contrib/other-builds/extract-mixed-syntax/Rules.cpp new file mode 100644 index 000000000..1b93430e2 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Rules.cpp @@ -0,0 +1,227 @@ +/* + * Rules.cpp + * + * Created on: 20 Feb 2014 + * Author: hieu + */ + +#include +#include "Rules.h" +#include "ConsistentPhrase.h" +#include "ConsistentPhrases.h" +#include "AlignedSentence.h" +#include "Rule.h" +#include "Parameter.h" +#include "moses/Util.h" + +using namespace std; + +extern bool g_debug; + +Rules::Rules(const AlignedSentence &alignedSentence) +:m_alignedSentence(alignedSentence) +{ +} + +Rules::~Rules() { + 
Moses::RemoveAllInColl(m_keepRules); +} + +void Rules::CreateRules(const ConsistentPhrase &cp, + const Parameter ¶ms) +{ + if (params.hieroSourceLHS) { + const NonTerm &nonTerm = cp.GetHieroNonTerm(); + CreateRule(nonTerm, params); + } + else { + const ConsistentPhrase::NonTerms &nonTerms = cp.GetNonTerms(); + for (size_t i = 0; i < nonTerms.size(); ++i) { + const NonTerm &nonTerm = nonTerms[i]; + CreateRule(nonTerm, params); + } + } +} + +void Rules::CreateRule(const NonTerm &nonTerm, + const Parameter ¶ms) +{ + Rule *rule = new Rule(nonTerm, m_alignedSentence); + + rule->Prevalidate(params); + rule->CreateTarget(params); + + + if (rule->CanRecurse()) { + Extend(*rule, params); + } + + if (rule->IsValid()) { + m_keepRules.insert(rule); + } + else { + delete rule; + } + +} + +void Rules::Extend(const Parameter ¶ms) +{ + const ConsistentPhrases &allCPS = m_alignedSentence.GetConsistentPhrases(); + + size_t size = m_alignedSentence.GetPhrase(Moses::Input).size(); + for (size_t sourceStart = 0; sourceStart < size; ++sourceStart) { + for (size_t sourceEnd = sourceStart; sourceEnd < size; ++sourceEnd) { + const ConsistentPhrases::Coll &cps = allCPS.GetColl(sourceStart, sourceEnd); + + ConsistentPhrases::Coll::const_iterator iter; + for (iter = cps.begin(); iter != cps.end(); ++iter) { + const ConsistentPhrase &cp = **iter; + CreateRules(cp, params); + } + } + } +} + +void Rules::Extend(const Rule &rule, const Parameter ¶ms) +{ + const ConsistentPhrases &allCPS = m_alignedSentence.GetConsistentPhrases(); + int sourceMin = rule.GetNextSourcePosForNonTerm(); + + int ruleStart = rule.GetConsistentPhrase().corners[0]; + int ruleEnd = rule.GetConsistentPhrase().corners[1]; + + for (int sourceStart = sourceMin; sourceStart <= ruleEnd; ++sourceStart) { + for (int sourceEnd = sourceStart; sourceEnd <= ruleEnd; ++sourceEnd) { + if (sourceStart == ruleStart && sourceEnd == ruleEnd) { + // don't cover whole rule with 1 non-term + continue; + } + + const ConsistentPhrases::Coll &cps 
= allCPS.GetColl(sourceStart, sourceEnd); + Extend(rule, cps, params); + } + } +} + +void Rules::Extend(const Rule &rule, const ConsistentPhrases::Coll &cps, const Parameter ¶ms) +{ + ConsistentPhrases::Coll::const_iterator iter; + for (iter = cps.begin(); iter != cps.end(); ++iter) { + const ConsistentPhrase &cp = **iter; + Extend(rule, cp, params); + } +} + +void Rules::Extend(const Rule &rule, const ConsistentPhrase &cp, const Parameter ¶ms) +{ + const ConsistentPhrase::NonTerms &nonTerms = cp.GetNonTerms(); + for (size_t i = 0; i < nonTerms.size(); ++i) { + const NonTerm &nonTerm = nonTerms[i]; + + Rule *newRule = new Rule(rule, nonTerm); + newRule->Prevalidate(params); + newRule->CreateTarget(params); + + if (newRule->CanRecurse()) { + // recursively extend + Extend(*newRule, params); + } + + if (newRule->IsValid()) { + m_keepRules.insert(newRule); + } + else { + delete newRule; + } + } +} + +std::string Rules::Debug() const +{ + stringstream out; + + std::set::const_iterator iter; + out << "m_keepRules:" << endl; + for (iter = m_keepRules.begin(); iter != m_keepRules.end(); ++iter) { + const Rule &rule = **iter; + out << rule.Debug() << endl; + } + + return out.str(); +} + +void Rules::Output(std::ostream &out, bool forward, const Parameter ¶ms) const +{ + std::set::const_iterator iter; + for (iter = m_mergeRules.begin(); iter != m_mergeRules.end(); ++iter) { + const Rule &rule = **iter; + rule.Output(out, forward, params); + out << endl; + } +} + +void Rules::Consolidate(const Parameter ¶ms) +{ + if (params.fractionalCounting) { + CalcFractionalCount(); + } + else { + std::set::iterator iter; + for (iter = m_keepRules.begin(); iter != m_keepRules.end(); ++iter) { + Rule &rule = **iter; + rule.SetCount(1); + } + } + + MergeRules(params); +} + +void Rules::MergeRules(const Parameter ¶ms) +{ + typedef std::set MergeRules; + + std::set::const_iterator iterOrig; + for (iterOrig = m_keepRules.begin(); iterOrig != m_keepRules.end(); ++iterOrig) { + Rule *origRule = 
*iterOrig; + + pair inserted = m_mergeRules.insert(origRule); + if (!inserted.second) { + // already there, just add count + Rule &rule = **inserted.first; + float newCount = rule.GetCount() + origRule->GetCount(); + rule.SetCount(newCount); + } + } +} + +void Rules::CalcFractionalCount() +{ + typedef std::set RuleColl; + typedef std::map RuleByConsistentPhrase; + RuleByConsistentPhrase allRules; + + // sort by source AND target ranges + std::set::const_iterator iter; + for (iter = m_keepRules.begin(); iter != m_keepRules.end(); ++iter) { + Rule *rule = *iter; + const ConsistentPhrase &cp = rule->GetConsistentPhrase(); + RuleColl &ruleColl = allRules[&cp]; + ruleColl.insert(rule); + } + + // fractional count + RuleByConsistentPhrase::iterator iterOuter; + for (iterOuter = allRules.begin(); iterOuter != allRules.end(); ++iterOuter) { + RuleColl &rules = iterOuter->second; + + RuleColl::iterator iterInner; + for (iterInner = rules.begin(); iterInner != rules.end(); ++iterInner) { + Rule &rule = **iterInner; + rule.SetCount(1.0f / (float) rules.size()); + } + } + +} + + diff --git a/contrib/other-builds/extract-mixed-syntax/Rules.h b/contrib/other-builds/extract-mixed-syntax/Rules.h new file mode 100644 index 000000000..6d8cb122d --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Rules.h @@ -0,0 +1,72 @@ +/* + * Rules.h + * + * Created on: 20 Feb 2014 + * Author: hieu + */ + +#pragma once + +#include +#include +#include "ConsistentPhrases.h" +#include "Rule.h" + +extern bool g_debug; + +class AlignedSentence; +class Parameter; + +struct CompareRules { + bool operator()(const Rule *a, const Rule *b) + { + int compare; + + compare = a->GetPhrase(Moses::Input).Compare(b->GetPhrase(Moses::Input)); + if (compare) return compare < 0; + + compare = a->GetPhrase(Moses::Output).Compare(b->GetPhrase(Moses::Output)); + if (compare) return compare < 0; + + if (a->GetAlignments() != b->GetAlignments()) { + return a->GetAlignments() < b->GetAlignments(); + } + + if 
(a->GetLHS().GetString() != b->GetLHS().GetString()) { + return a->GetLHS().GetString() < b->GetLHS().GetString(); + } + + return false; + } +}; + +class Rules { +public: + Rules(const AlignedSentence &alignedSentence); + virtual ~Rules(); + void Extend(const Parameter ¶ms); + void Consolidate(const Parameter ¶ms); + + std::string Debug() const; + void Output(std::ostream &out, bool forward, const Parameter ¶ms) const; + +protected: + const AlignedSentence &m_alignedSentence; + std::set m_keepRules; + std::set m_mergeRules; + + void Extend(const Rule &rule, const Parameter ¶ms); + void Extend(const Rule &rule, const ConsistentPhrases::Coll &cps, const Parameter ¶ms); + void Extend(const Rule &rule, const ConsistentPhrase &cp, const Parameter ¶ms); + + // create original rules + void CreateRules(const ConsistentPhrase &cp, + const Parameter ¶ms); + void CreateRule(const NonTerm &nonTerm, + const Parameter ¶ms); + + void MergeRules(const Parameter ¶ms); + void CalcFractionalCount(); + +}; + diff --git a/contrib/other-builds/extract-mixed-syntax/SentenceAlignment.cpp b/contrib/other-builds/extract-mixed-syntax/SentenceAlignment.cpp deleted file mode 100644 index b13743bc1..000000000 --- a/contrib/other-builds/extract-mixed-syntax/SentenceAlignment.cpp +++ /dev/null @@ -1,331 +0,0 @@ -/* - * SentenceAlignment.cpp - * extract - * - * Created by Hieu Hoang on 19/01/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ -#include -#include -#include -#include "SentenceAlignment.h" -#include "XmlTree.h" -#include "tables-core.h" -#include "TunnelCollection.h" -#include "Lattice.h" -#include "LatticeNode.h" - -using namespace std; - -extern std::set< std::string > targetLabelCollection, sourceLabelCollection; -extern std::map< std::string, int > targetTopLabelCollection, sourceTopLabelCollection; - -SentenceAlignment::SentenceAlignment() -:m_tunnelCollection(NULL) -,m_lattice(NULL) -{} - -SentenceAlignment::~SentenceAlignment() -{ - delete m_tunnelCollection; - delete m_lattice; -} - -int SentenceAlignment::Create( const std::string &targetString, const std::string &sourceString, const std::string &alignmentString, int sentenceID, const Global &global ) -{ - - // tokenizing English (and potentially extract syntax spans) - if (global.targetSyntax) { - string targetStringCPP = string(targetString); - ProcessAndStripXMLTags( targetStringCPP, targetTree, targetLabelCollection , targetTopLabelCollection ); - target = tokenize( targetStringCPP.c_str() ); - // cerr << "E: " << targetStringCPP << endl; - } - else { - target = tokenize( targetString.c_str() ); - } - - // tokenizing source (and potentially extract syntax spans) - if (global.sourceSyntax) { - string sourceStringCPP = string(sourceString); - ProcessAndStripXMLTags( sourceStringCPP, sourceTree, sourceLabelCollection , sourceTopLabelCollection ); - source = tokenize( sourceStringCPP.c_str() ); - // cerr << "F: " << sourceStringCPP << endl; - } - else { - source = tokenize( sourceString.c_str() ); - } - - // check if sentences are empty - if (target.size() == 0 || source.size() == 0) { - cerr << "no target (" << target.size() << ") or source (" << source.size() << ") words << end insentence " << sentenceID << endl; - cerr << "T: " << targetString << endl << "S: " << sourceString << endl; - return 0; - } - - // prepare data structures for alignments - for(int i=0; i dummy; - alignedToT.push_back( dummy ); - } - - 
//InitTightest(m_s2tTightest, source.size()); - //InitTightest(m_t2sTightest, target.size()); - - - // reading in alignments - vector alignmentSequence = tokenize( alignmentString.c_str() ); - for(int i=0; i= target.size() || s >= source.size()) { - cerr << "WARNING: sentence " << sentenceID << " has alignment point (" << s << ", " << t << ") out of bounds (" << source.size() << ", " << target.size() << ")\n"; - cerr << "T: " << targetString << endl << "S: " << sourceString << endl; - return 0; - } - alignedToT[t].push_back( s ); - alignedCountS[s]++; - - //SetAlignment(s, t); - } - - bool mixed = global.mixed; - sourceTree.AddDefaultNonTerms(global.sourceSyntax, mixed, source.size()); - targetTree.AddDefaultNonTerms(global.targetSyntax, mixed, target.size()); - - //CalcTightestSpan(m_s2tTightest); - //CalcTightestSpan(m_t2sTightest); - - return 1; -} - -/* -void SentenceAlignment::InitTightest(Outer &tightest, size_t len) -{ - tightest.resize(len); - - for (size_t posOuter = 0; posOuter < len; ++posOuter) - { - Inner &inner = tightest[posOuter]; - size_t innerSize = len - posOuter; - inner.resize(innerSize); - - } -} - -void SentenceAlignment::CalcTightestSpan(Outer &tightest) -{ - size_t len = tightest.size(); - - for (size_t startPos = 0; startPos < len; ++startPos) - { - for (size_t endPos = startPos + 1; endPos < len; ++endPos) - { - const Range &prevRange = GetTightest(tightest, startPos, endPos - 1); - const Range &smallRange = GetTightest(tightest, endPos, endPos); - Range &newRange = GetTightest(tightest, startPos, endPos); - - newRange.Merge(prevRange, smallRange); - //cerr << "[" << startPos << "-" << endPos << "] --> [" << newRange.GetStartPos() << "-" << newRange.GetEndPos() << "]"; - } - } -} - -Range &SentenceAlignment::GetTightest(Outer &tightest, size_t startPos, size_t endPos) -{ - assert(endPos < tightest.size()); - assert(endPos >= startPos); - - Inner &inner = tightest[startPos]; - - size_t ind = endPos - startPos; - Range &ret = inner[ind]; - 
return ret; -} - -void SentenceAlignment::SetAlignment(size_t source, size_t target) -{ - SetAlignment(m_s2tTightest, source, target); - SetAlignment(m_t2sTightest, target, source); -} - -void SentenceAlignment::SetAlignment(Outer &tightest, size_t thisPos, size_t thatPos) -{ - - Range &range = GetTightest(tightest, thisPos, thisPos); - if (range.GetStartPos() == NOT_FOUND) - { // not yet set, do them both - assert(range.GetEndPos() == NOT_FOUND); - range.SetStartPos(thatPos); - range.SetEndPos(thatPos); - } - else - { - assert(range.GetEndPos() != NOT_FOUND); - range.SetStartPos( (range.GetStartPos() > thatPos) ? thatPos : range.GetStartPos() ); - range.SetEndPos( (range.GetEndPos() < thatPos) ? thatPos : range.GetEndPos() ); - } -} - */ - - -void SentenceAlignment::FindTunnels(const Global &global ) -{ - int countT = target.size(); - int countS = source.size(); - int maxSpan = max(global.maxHoleSpanSourceDefault, global.maxHoleSpanSourceSyntax); - - m_tunnelCollection = new TunnelCollection(countS); - - m_tunnelCollection->alignedCountS = alignedCountS; - m_tunnelCollection->alignedCountT.resize(alignedToT.size()); - for (size_t ind = 0; ind < alignedToT.size(); ind++) - { - m_tunnelCollection->alignedCountT[ind] = alignedToT[ind].size(); - } - - // phrase repository for creating hiero phrases - - // check alignments for target phrase startT...endT - for(int lengthT=1; - lengthT <= maxSpan && lengthT <= countT; - lengthT++) { - for(int startT=0; startT < countT-(lengthT-1); startT++) { - - // that's nice to have - int endT = startT + lengthT - 1; - - // if there is target side syntax, there has to be a node - if (global.targetSyntax && !targetTree.HasNode(startT,endT)) - continue; - - // find find aligned source words - // first: find minimum and maximum source word - int minS = 9999; - int maxS = -1; - vector< int > usedS = alignedCountS; - for(int ti=startT;ti<=endT;ti++) { - for(int i=0;imaxS) { maxS = si; } - usedS[ si ]--; - } - } - - // unaligned phrases 
are not allowed - if( maxS == -1 ) - continue; - - // source phrase has to be within limits - if( maxS-minS >= maxSpan ) - { - continue; - } - - // check if source words are aligned to out of bound target words - bool out_of_bounds = false; - for(int si=minS;si<=maxS && !out_of_bounds;si++) - { - if (usedS[si]>0) { - out_of_bounds = true; - } - } - - // if out of bound, you gotta go - if (out_of_bounds) - continue; - - if (m_tunnelCollection->NumUnalignedWord(1, startT, endT) >= global.maxUnaligned) - continue; - - // done with all the checks, lets go over all consistent phrase pairs - // start point of source phrase may retreat over unaligned - for(int startS=minS; - (startS>=0 && - startS>maxS - maxSpan && // within length limit - (startS==minS || alignedCountS[startS]==0)); // unaligned - startS--) - { - // end point of source phrase may advance over unaligned - for(int endS=maxS; - (endSNumUnalignedWord(0, startS, endS) >= global.maxUnaligned) - continue; - - // take note that this is a valid phrase alignment - m_tunnelCollection->Add(startS, endS, startT, endT); - } - } - } - } - - //cerr << *tunnelCollection << endl; - -} - -void SentenceAlignment::CreateLattice(const Global &global) -{ - size_t countS = source.size(); - m_lattice = new Lattice(countS); - - for (size_t startPos = 0; startPos < countS; ++startPos) - { - //cerr << "creating arcs for " << startPos << "="; - m_lattice->CreateArcs(startPos, *m_tunnelCollection, *this, global); - - //cerr << LatticeNode::s_count << endl; - } -} - -void SentenceAlignment::CreateRules(const Global &global) -{ - size_t countS = source.size(); - - for (size_t startPos = 0; startPos < countS; ++startPos) - { - //cerr << "creating rules for " << startPos << "\n"; - m_lattice->CreateRules(startPos, *this, global); - } -} - -void OutputSentenceStr(std::ostream &out, const std::vector &vec) -{ - for (size_t pos = 0; pos < vec.size(); ++pos) - { - out << vec[pos] << " "; - } -} - -std::ostream& operator<<(std::ostream &out, 
const SentenceAlignment &obj) -{ - OutputSentenceStr(out, obj.target); - out << " ==> "; - OutputSentenceStr(out, obj.source); - out << endl; - - out << *obj.m_tunnelCollection; - - if (obj.m_lattice) - out << endl << *obj.m_lattice; - - return out; -} - - - - diff --git a/contrib/other-builds/extract-mixed-syntax/SentenceAlignment.h b/contrib/other-builds/extract-mixed-syntax/SentenceAlignment.h deleted file mode 100644 index a94941309..000000000 --- a/contrib/other-builds/extract-mixed-syntax/SentenceAlignment.h +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once -/* - * SentenceAlignment.h - * extract - * - * Created by Hieu Hoang on 19/01/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. - * - */ -#include -#include -#include -#include "SyntaxTree.h" -#include "Global.h" -#include "Range.h" - -class TunnelCollection; -class Lattice; - -class SentenceAlignment -{ - friend std::ostream& operator<<(std::ostream&, const SentenceAlignment&); - -public: - std::vector target; - std::vector source; - std::vector alignedCountS; - std::vector< std::vector > alignedToT; - SyntaxTree sourceTree, targetTree; - - //typedef std::vector Inner; - //typedef std::vector Outer; - - //Outer m_s2tTightest, m_t2sTightest; - - SentenceAlignment(); - ~SentenceAlignment(); - int Create(const std::string &targetString, const std::string &sourceString, const std::string &alignmentString, int sentenceID, const Global &global); - // void clear() { delete(alignment); }; - void FindTunnels( const Global &global ) ; - - void CreateLattice(const Global &global); - void CreateRules(const Global &global); - - const TunnelCollection &GetTunnelCollection() const - { - assert(m_tunnelCollection); - return *m_tunnelCollection; - } - - const Lattice &GetLattice() const - { - assert(m_lattice); - return *m_lattice; - } - -protected: - TunnelCollection *m_tunnelCollection; - Lattice *m_lattice; - - /* - void CalcTightestSpan(Outer &tightest); - void InitTightest(Outer &tightest, size_t len); - 
Range &GetTightest(Outer &tightest, size_t startPos, size_t endPos); - void SetAlignment(size_t source, size_t target); - void SetAlignment(Outer &tightest, size_t thisPos, size_t thatPos); - */ -}; - diff --git a/contrib/other-builds/extract-mixed-syntax/Symbol.cpp b/contrib/other-builds/extract-mixed-syntax/Symbol.cpp deleted file mode 100644 index 0181dcaeb..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Symbol.cpp +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Symbol.cpp - * extract - * - * Created by Hieu Hoang on 21/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. - * - */ -#include -#include "Symbol.h" - -using namespace std; - -Symbol::Symbol(const std::string &label, size_t pos) -:m_label(label) -,m_isTerminal(true) -,m_span(2) -{ - m_span[0].first = pos; -} - -Symbol::Symbol(const std::string &labelS, const std::string &labelT - , size_t startS, size_t endS - , size_t startT, size_t endT - , bool isSourceSyntax, bool isTargetSyntax) -:m_label(labelS) -,m_labelT(labelT) -,m_isTerminal(false) -,m_span(2) -,m_isSourceSyntax(isSourceSyntax) -,m_isTargetSyntax(isTargetSyntax) -{ - m_span[0] = std::pair(startS, endS); - m_span[1] = std::pair(startT, endT); -} - -int CompareNonTerm(bool thisIsSyntax, bool otherIsSyntax - , const std::pair &thisSpan, const std::pair &otherSpan - , std::string thisLabel, std::string otherLabel) -{ - if (thisIsSyntax != otherIsSyntax) - { // 1 is [X] & the other is [NP] on the source - return thisIsSyntax ? -1 : +1; - } - - assert(thisIsSyntax == otherIsSyntax); - if (thisIsSyntax) - { // compare span & label - if (thisSpan != otherSpan) - return thisSpan < otherSpan ? -1 : +1; - if (thisLabel != otherLabel) - return thisLabel < otherLabel ? -1 : +1; - } - - return 0; -} - -int Symbol::Compare(const Symbol &other) const -{ - if (m_isTerminal != other.m_isTerminal) - return m_isTerminal ? 
-1 : +1; - - assert(m_isTerminal == other.m_isTerminal); - if (m_isTerminal) - { // compare labels & pos - if (m_span[0].first != other.m_span[0].first) - return (m_span[0].first < other.m_span[0].first) ? -1 : +1; - - if (m_label != other.m_label) - return (m_label < other.m_label) ? -1 : +1; - - } - else - { // non terms - int ret = CompareNonTerm(m_isSourceSyntax, other.m_isSourceSyntax - ,m_span[0], other.m_span[0] - ,m_label, other.m_label); - if (ret != 0) - return ret; - - ret = CompareNonTerm(m_isTargetSyntax, other.m_isTargetSyntax - ,m_span[1], other.m_span[1] - ,m_label, other.m_label); - if (ret != 0) - return ret; - } - - return 0; -} - - -std::ostream& operator<<(std::ostream &out, const Symbol &obj) -{ - if (obj.m_isTerminal) - out << obj.m_label; - else - out << obj.m_label + obj.m_labelT; - - return out; -} - diff --git a/contrib/other-builds/extract-mixed-syntax/Symbol.h b/contrib/other-builds/extract-mixed-syntax/Symbol.h deleted file mode 100644 index b79a705b2..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Symbol.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -/* - * Symbol.h - * extract - * - * Created by Hieu Hoang on 21/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ -#include -#include -#include - -class Symbol -{ - friend std::ostream& operator<<(std::ostream &out, const Symbol &obj); - -protected: - std::string m_label, m_labelT; // m_labelT only for non-term - std::vector > m_span; - - bool m_isTerminal, m_isSourceSyntax, m_isTargetSyntax; -public: - // for terminals - Symbol(const std::string &label, size_t pos); - - // for non-terminals - Symbol(const std::string &labelS, const std::string &labelT - , size_t startS, size_t endS - , size_t startT, size_t endT - , bool isSourceSyntax, bool isTargetSyntax); - - int Compare(const Symbol &other) const; - -}; diff --git a/contrib/other-builds/extract-mixed-syntax/SymbolSequence.cpp b/contrib/other-builds/extract-mixed-syntax/SymbolSequence.cpp deleted file mode 100644 index 0cf19f664..000000000 --- a/contrib/other-builds/extract-mixed-syntax/SymbolSequence.cpp +++ /dev/null @@ -1,56 +0,0 @@ -/* - * SymbolSequence.cpp - * extract - * - * Created by Hieu Hoang on 21/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. - * - */ -#include -#include -#include "SymbolSequence.h" - -using namespace std; - -int SymbolSequence::Compare(const SymbolSequence &other) const -{ - int ret; - size_t thisSize = GetSize(); - size_t otherSize = other.GetSize(); - if (thisSize != otherSize) - { - ret = (thisSize < otherSize) ? 
-1 : +1; - return ret; - } - else - { - assert(thisSize == otherSize); - for (size_t ind = 0; ind < thisSize; ++ind) - { - const Symbol &thisSymbol = GetSymbol(ind); - const Symbol &otherSymbol = other.GetSymbol(ind); - ret = thisSymbol.Compare(otherSymbol); - if (ret != 0) - { - return ret; - } - } - } - - assert(ret == 0); - return ret; -} - -std::ostream& operator<<(std::ostream &out, const SymbolSequence &obj) -{ - SymbolSequence::CollType::const_iterator iterSymbol; - for (iterSymbol = obj.m_coll.begin(); iterSymbol != obj.m_coll.end(); ++iterSymbol) - { - const Symbol &symbol = *iterSymbol; - out << symbol << " "; - } - - return out; -} - - diff --git a/contrib/other-builds/extract-mixed-syntax/SymbolSequence.h b/contrib/other-builds/extract-mixed-syntax/SymbolSequence.h deleted file mode 100644 index 997c24205..000000000 --- a/contrib/other-builds/extract-mixed-syntax/SymbolSequence.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once -/* - * SymbolSequence.h - * extract - * - * Created by Hieu Hoang on 21/07/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ -#include -#include -#include "Symbol.h" - -class SymbolSequence -{ - friend std::ostream& operator<<(std::ostream &out, const SymbolSequence &obj); - -protected: - typedef std::vector CollType; - CollType m_coll; - -public: - typedef CollType::iterator iterator; - typedef CollType::const_iterator const_iterator; - const_iterator begin() const { return m_coll.begin(); } - const_iterator end() const { return m_coll.end(); } - - void Add(const Symbol &symbol) - { - m_coll.push_back(symbol); - } - size_t GetSize() const - { return m_coll.size(); } - const Symbol &GetSymbol(size_t ind) const - { return m_coll[ind]; } - - void Clear() - { m_coll.clear(); } - - int Compare(const SymbolSequence &other) const; - -}; diff --git a/contrib/other-builds/extract-mixed-syntax/SyntaxTree.cpp b/contrib/other-builds/extract-mixed-syntax/SyntaxTree.cpp index a6ba3de7b..472444e7c 100644 --- a/contrib/other-builds/extract-mixed-syntax/SyntaxTree.cpp +++ b/contrib/other-builds/extract-mixed-syntax/SyntaxTree.cpp @@ -1,245 +1,47 @@ -// $Id: SyntaxTree.cpp 1960 2008-12-15 12:52:38Z phkoehn $ -// vim:tabstop=2 - -/*********************************************************************** - Moses - factored phrase-based language decoder - Copyright (C) 2009 University of Edinburgh - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. 
- - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - ***********************************************************************/ - - -#include #include +#include #include "SyntaxTree.h" -//#include "extract.h" -#include "Global.h" - -//extern const Global g_debug; -extern const Global *g_global; +#include "Parameter.h" using namespace std; -bool SyntaxNode::IsSyntax() const +void SyntaxTree::Add(int startPos, int endPos, const std::string &label, const Parameter ¶ms) { - bool ret = GetLabel() != "[X]"; - return ret; -} + //cerr << "add " << label << " to " << "[" << startPos << "-" << endPos << "]" << endl; -SyntaxTree::SyntaxTree() -:m_defaultLHS(0,0, "[X]") -{ - m_emptyNode.clear(); -} + Range range(startPos, endPos); + Labels &labels = m_coll[range]; -SyntaxTree::~SyntaxTree() -{ - // loop through all m_nodes, delete them - for(int i=0; iuppermostOnly) - { - nodesChart.push_back( newNode ); - //assert(!HasDuplicates(m_index[ startPos ][ endPos ])); - } - else - { - if (nodesChart.size() > 0) - { - assert(nodesChart.size() == 1); - //delete nodes[0]; - nodesChart.resize(0); + bool add = true; + if (labels.size()) { + if (params.multiLabel == 1) { + // delete the label in collection and add new + assert(labels.size() == 1); + labels.clear(); } - assert(nodesChart.size() == 0); - nodesChart.push_back( newNode ); - } -} - -ParentNodes SyntaxTree::Parse() { - ParentNodes parents; - - int size = m_index.size(); - - // looping through all spans of size >= 2 - for( int length=2; length<=size; length++ ) - { - for( int startPos = 0; startPos <= size-length; startPos++ ) - { - if (HasNode( startPos, startPos+length-1 )) - { - // processing one (parent) span - - //std::cerr << "# " << startPos << "-" << (startPos+length-1) << ":"; - SplitPoints splitPoints; - splitPoints.push_back( startPos ); - //std::cerr << " 
" << startPos; - - int first = 1; - int covered = 0; - while( covered < length ) - { - // find largest covering subspan (child) - // starting at last covered position - for( int midPos=length-first; midPos>covered; midPos-- ) - { - if( HasNode( startPos+covered, startPos+midPos-1 ) ) - { - covered = midPos; - splitPoints.push_back( startPos+covered ); - // std::cerr << " " << ( startPos+covered ); - first = 0; - } - } - } - // std::cerr << std::endl; - parents.push_back( splitPoints ); - } + else if (params.multiLabel == 2) { + // ignore this label + add = false; } } - return parents; -} -bool SyntaxTree::HasNode( int startPos, int endPos ) const -{ - return GetNodes( startPos, endPos).size() > 0; -} - -const SyntaxNodes &SyntaxTree::GetNodes( int startPos, int endPos ) const -{ - SyntaxTreeIndexIterator startIndex = m_index.find( startPos ); - if (startIndex == m_index.end() ) - return m_emptyNode; - - SyntaxTreeIndexIterator2 endIndex = startIndex->second.find( endPos ); - if (endIndex == startIndex->second.end()) - return m_emptyNode; - - return endIndex->second; -} - -// for printing out tree -std::string SyntaxTree::ToString() const -{ - std::stringstream out; - out << *this; - return out.str(); -} - -void SyntaxTree::AddDefaultNonTerms(size_t phraseSize) -{ - for (size_t startPos = 0; startPos <= phraseSize; ++startPos) - { - for (size_t endPos = startPos; endPos < phraseSize; ++endPos) - { - AddNode(startPos, endPos, "X"); - } + if (add) { + labels.push_back(label); } } -void SyntaxTree::AddDefaultNonTerms(bool isSyntax, bool mixed, size_t phraseSize) +void SyntaxTree::AddToAll(const std::string &label) { - if (isSyntax) - { - AddDefaultNonTerms(!mixed, phraseSize); - } - else - { // add X everywhere - AddDefaultNonTerms(phraseSize); + Coll::iterator iter; + for (iter = m_coll.begin(); iter != m_coll.end(); ++iter) { + Labels &labels = iter->second; + labels.push_back(label); } } -void SyntaxTree::AddDefaultNonTerms(bool addEverywhere, size_t phraseSize) 
+const SyntaxTree::Labels &SyntaxTree::Find(int startPos, int endPos) const { - //cerr << "GetNumWords()=" << GetNumWords() << endl; - //assert(phraseSize == GetNumWords() || GetNumWords() == 1); // 1 if syntax sentence doesn't have any xml. TODO fix syntax tree obj - - for (size_t startPos = 0; startPos <= phraseSize; ++startPos) - { - for (size_t endPos = startPos; endPos <= phraseSize; ++endPos) - { - const SyntaxNodes &nodes = GetNodes(startPos, endPos); - if (!addEverywhere && nodes.size() > 0) - { // only add if no label - continue; - } - AddNode(startPos, endPos, "X"); - } - } + Coll::const_iterator iter; + iter = m_coll.find(Range(startPos, endPos)); + return (iter == m_coll.end()) ? m_defaultLabels : iter->second; } - -const SyntaxNodes SyntaxTree::GetNodesForLHS( int startPos, int endPos ) const -{ - SyntaxNodes ret(GetNodes(startPos, endPos)); - - if (ret.size() == 0) - ret.push_back(&m_defaultLHS); - - return ret; -} - -std::ostream& operator<<(std::ostream& os, const SyntaxTree& t) -{ - int size = t.m_index.size(); - for(size_t length=1; length<=size; length++) - { - for(size_t space=0; spaceGetLabel() + "#######"; - - os << label.substr(0,7) << " "; - } - else - { - os << "------- "; - } - } - os << std::endl; - } - return os; -} - - diff --git a/contrib/other-builds/extract-mixed-syntax/SyntaxTree.h b/contrib/other-builds/extract-mixed-syntax/SyntaxTree.h index 50a73a369..58f718151 100644 --- a/contrib/other-builds/extract-mixed-syntax/SyntaxTree.h +++ b/contrib/other-builds/extract-mixed-syntax/SyntaxTree.h @@ -1,96 +1,32 @@ #pragma once -// $Id: SyntaxTree.h 1960 2008-12-15 12:52:38Z phkoehn $ -// vim:tabstop=2 - -/*********************************************************************** - Moses - factored phrase-based language decoder - Copyright (C) 2009 University of Edinburgh - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software 
Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. - - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - ***********************************************************************/ - -#include #include #include -#include +#include -class SyntaxNode; +class Parameter; -typedef std::vector SyntaxNodes; - -class SyntaxNode { -protected: - int m_start, m_end; - std::string m_label; - SyntaxNodes m_children; - SyntaxNode* m_parent; +class SyntaxTree +{ public: -SyntaxNode( int startPos, int endPos, const std::string &label) - :m_start(startPos) - ,m_end(endPos) - ,m_label(label) - {} - int GetStart() const - { return m_start; } - int GetEnd() const - { return m_end; } - const std::string &GetLabel() const - { return m_label; } - bool IsSyntax() const; + typedef std::pair Range; + typedef std::vector Labels; + typedef std::map Coll; + + void Add(int startPos, int endPos, const std::string &label, const Parameter ¶ms); + void AddToAll(const std::string &label); + + const Labels &Find(int startPos, int endPos) const; + + void SetHieroLabel(const std::string &label) { + m_defaultLabels.push_back(label); + } + + +protected: + + Coll m_coll; + Labels m_defaultLabels; }; -typedef std::vector< int > SplitPoints; -typedef std::vector< SplitPoints > ParentNodes; - -class SyntaxTree { -protected: - SyntaxNodes m_nodes; - SyntaxNode* m_top; - SyntaxNode m_defaultLHS; - - typedef std::map< int, SyntaxNodes > SyntaxTreeIndex2; - typedef SyntaxTreeIndex2::const_iterator SyntaxTreeIndexIterator2; - typedef std::map< int, SyntaxTreeIndex2 > 
SyntaxTreeIndex; - typedef SyntaxTreeIndex::const_iterator SyntaxTreeIndexIterator; - SyntaxTreeIndex m_index; - SyntaxNodes m_emptyNode; - - friend std::ostream& operator<<(std::ostream&, const SyntaxTree&); - -public: - SyntaxTree(); - ~SyntaxTree(); - - void AddNode( int startPos, int endPos, std::string label ); - ParentNodes Parse(); - bool HasNode( int startPos, int endPos ) const; - const SyntaxNodes &GetNodes( int startPos, int endPos ) const; - const SyntaxNodes &GetAllNodes() const { return m_nodes; } ; - size_t GetNumWords() const { return m_index.size(); } - std::string ToString() const; - - void AddDefaultNonTerms(bool isSyntax, bool addEverywhere, size_t phraseSize); - void AddDefaultNonTerms(bool mixed, size_t phraseSize); - - void AddDefaultNonTerms(size_t phraseSize); - - const SyntaxNodes GetNodesForLHS( int startPos, int endPos ) const; - -}; - -std::ostream& operator<<(std::ostream&, const SyntaxTree&); - diff --git a/contrib/other-builds/extract-mixed-syntax/Tunnel.cpp b/contrib/other-builds/extract-mixed-syntax/Tunnel.cpp deleted file mode 100644 index fc4846c34..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Tunnel.cpp +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Tunnel.cpp - * extract - * - * Created by Hieu Hoang on 19/01/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. - * - */ - -#include "Tunnel.h" - - -int Tunnel::Compare(const Tunnel &other) const -{ - int ret = m_sourceRange.Compare(other.m_sourceRange); - - if (ret != 0) - return ret; - - ret = m_targetRange.Compare(other.m_targetRange); - - return ret; -} - -int Tunnel::Compare(const Tunnel &other, size_t direction) const -{ - const Range &thisRange = (direction == 0) ? m_sourceRange : m_targetRange; - const Range &otherRange = (direction == 0) ? 
other.m_sourceRange : other.m_targetRange; - - int ret = thisRange.Compare(otherRange); - return ret; -} - -std::ostream& operator<<(std::ostream &out, const Tunnel &tunnel) -{ - out << tunnel.m_sourceRange << "==>" << tunnel.m_targetRange; - return out; -} diff --git a/contrib/other-builds/extract-mixed-syntax/Tunnel.h b/contrib/other-builds/extract-mixed-syntax/Tunnel.h deleted file mode 100644 index 2659cca4a..000000000 --- a/contrib/other-builds/extract-mixed-syntax/Tunnel.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -/* - * Tunnel.h - * extract - * - * Created by Hieu Hoang on 19/01/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. - * - */ -#include -#include -#include -#include -#include "Range.h" - - // for unaligned source terminal - -class Tunnel -{ - friend std::ostream& operator<<(std::ostream&, const Tunnel&); - -protected: - - Range m_sourceRange, m_targetRange; - -public: - Tunnel() - {} - - Tunnel(const Tunnel ©) - :m_sourceRange(copy.m_sourceRange) - ,m_targetRange(copy.m_targetRange) - {} - - Tunnel(const Range &sourceRange, const Range &targetRange) - :m_sourceRange(sourceRange) - ,m_targetRange(targetRange) - {} - - const Range &GetRange(size_t direction) const - { return (direction == 0) ? m_sourceRange : m_targetRange; } - - int Compare(const Tunnel &other) const; - int Compare(const Tunnel &other, size_t direction) const; -}; - -typedef std::vector TunnelList; - diff --git a/contrib/other-builds/extract-mixed-syntax/TunnelCollection.cpp b/contrib/other-builds/extract-mixed-syntax/TunnelCollection.cpp deleted file mode 100644 index 228cc3070..000000000 --- a/contrib/other-builds/extract-mixed-syntax/TunnelCollection.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* - * TunnelCollection.cpp - * extract - * - * Created by Hieu Hoang on 19/01/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ - -#include "TunnelCollection.h" -#include "Range.h" - -using namespace std; - -size_t TunnelCollection::NumUnalignedWord(size_t direction, size_t startPos, size_t endPos) const -{ - assert(startPos <= endPos); - - if (direction == 0) - assert(endPos < alignedCountS.size()); - else - assert(endPos < alignedCountT.size()); - - size_t ret = 0; - for (size_t ind = startPos; ind <= endPos; ++ind) - { - if (direction == 0 && alignedCountS[ind] == 0) - { - ret++; - } - else if (direction == 1 && alignedCountT[ind] == 0) - { - ret++; - } - - } - - return ret; -} - -void TunnelCollection::Add(int startS, int endS, int startT, int endT) -{ - // m_phraseExist[startS][endS - startS].push_back(Tunnel(startT, endT)); - m_coll[startS][endS - startS].push_back(Tunnel(Range(startS, endS), Range(startT, endT))); -} - - -std::ostream& operator<<(std::ostream &out, const TunnelCollection &TunnelCollection) -{ - size_t size = TunnelCollection.GetSize(); - - for (size_t startPos = 0; startPos < size; ++startPos) - { - for (size_t endPos = startPos; endPos < size; ++endPos) - { - const TunnelList &tunnelList = TunnelCollection.GetTunnels(startPos, endPos); - TunnelList::const_iterator iter; - for (iter = tunnelList.begin(); iter != tunnelList.end(); ++iter) - { - const Tunnel &tunnel = *iter; - out << tunnel << " "; - - } - } - } - - return out; -} - - diff --git a/contrib/other-builds/extract-mixed-syntax/TunnelCollection.h b/contrib/other-builds/extract-mixed-syntax/TunnelCollection.h deleted file mode 100644 index 547cbf814..000000000 --- a/contrib/other-builds/extract-mixed-syntax/TunnelCollection.h +++ /dev/null @@ -1,61 +0,0 @@ -#pragma once -/* - * TunnelCollection.h - * extract - * - * Created by Hieu Hoang on 19/01/2010. - * Copyright 2010 __MyCompanyName__. All rights reserved. 
- * - */ -#include -#include "Tunnel.h" - -// reposity of extracted phrase pairs -// which are potential tunnels in larger phrase pairs -class TunnelCollection - { - friend std::ostream& operator<<(std::ostream&, const TunnelCollection&); - - protected: - std::vector< std::vector > m_coll; - // indexed by source pos. and source length - // maps to list of tunnels where are target pos - - public: - std::vector alignedCountS, alignedCountT; - - TunnelCollection(const TunnelCollection &); - - TunnelCollection(size_t size) - :m_coll(size) - { - // size is the length of the source sentence - for (size_t pos = 0; pos < size; ++pos) - { - // create empty tunnel lists - std::vector &endVec = m_coll[pos]; - endVec.resize(size - pos); - } - } - - void Add(int startS, int endS, int startT, int endT); - - //const TunnelList &GetTargetHoles(int startS, int endS) const - //{ - // const TunnelList &targetHoles = m_phraseExist[startS][endS - startS]; - // return targetHoles; - //} - const TunnelList &GetTunnels(int startS, int endS) const - { - const TunnelList &sourceHoles = m_coll[startS][endS - startS]; - return sourceHoles; - } - - const size_t GetSize() const - { return m_coll.size(); } - - size_t NumUnalignedWord(size_t direction, size_t startPos, size_t endPos) const; - - - }; - diff --git a/contrib/other-builds/extract-mixed-syntax/Word.cpp b/contrib/other-builds/extract-mixed-syntax/Word.cpp new file mode 100644 index 000000000..691266874 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Word.cpp @@ -0,0 +1,56 @@ +/* + * Word.cpp + * + * Created on: 18 Feb 2014 + * Author: s0565741 + */ +#include +#include "Word.h" + +using namespace std; + +Word::Word(int pos, const std::string &str) +:m_pos(pos) +,m_str(str) +{ + // TODO Auto-generated constructor stub + +} + +Word::~Word() { + // TODO Auto-generated destructor stub +} + +void Word::AddAlignment(const Word *other) +{ + m_alignment.insert(other); +} + +std::set Word::GetAlignmentIndex() const +{ + std::set 
ret; + + std::set::const_iterator iter; + for (iter = m_alignment.begin(); iter != m_alignment.end(); ++iter) { + const Word &otherWord = **iter; + int otherPos = otherWord.GetPos(); + ret.insert(otherPos); + } + + return ret; +} + +void Word::Output(std::ostream &out) const +{ + out << m_str; +} + +std::string Word::Debug() const +{ + return m_str; +} + +int Word::CompareString(const Word &other) const +{ + return m_str.compare(other.m_str); +} diff --git a/contrib/other-builds/extract-mixed-syntax/Word.h b/contrib/other-builds/extract-mixed-syntax/Word.h new file mode 100644 index 000000000..2f4600166 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/Word.h @@ -0,0 +1,47 @@ +/* + * Word.h + * + * Created on: 18 Feb 2014 + * Author: s0565741 + */ +#pragma once + +#include +#include +#include "RuleSymbol.h" + +// a terminal +class Word : public RuleSymbol +{ +public: + Word(const Word&); // do not implement + Word(int pos, const std::string &str); + virtual ~Word(); + + virtual bool IsNonTerm() const + { return false; } + + std::string GetString() const + { return m_str; } + + int GetPos() const + { return m_pos; } + + void AddAlignment(const Word *other); + + const std::set &GetAlignment() const + { return m_alignment; } + + std::set GetAlignmentIndex() const; + + void Output(std::ostream &out) const; + std::string Debug() const; + + int CompareString(const Word &other) const; + +protected: + int m_pos; // original position in sentence, NOT in lattice + std::string m_str; + std::set m_alignment; +}; + diff --git a/contrib/other-builds/extract-mixed-syntax/XmlTree.cpp b/contrib/other-builds/extract-mixed-syntax/XmlTree.cpp deleted file mode 100644 index 9145c8d1c..000000000 --- a/contrib/other-builds/extract-mixed-syntax/XmlTree.cpp +++ /dev/null @@ -1,344 +0,0 @@ -// $Id: XmlOption.cpp 1960 2008-12-15 12:52:38Z phkoehn $ -// vim:tabstop=2 - -/*********************************************************************** - Moses - factored phrase-based 
language decoder - Copyright (C) 2006 University of Edinburgh - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. - - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - ***********************************************************************/ - -#include -#include -#include -#include -#include -#include "SyntaxTree.h" - -using namespace std; - - -inline std::vector Tokenize(const std::string& str, - const std::string& delimiters = " \t") -{ - std::vector tokens; - // Skip delimiters at beginning. - std::string::size_type lastPos = str.find_first_not_of(delimiters, 0); - // Find first "non-delimiter". - std::string::size_type pos = str.find_first_of(delimiters, lastPos); - - while (std::string::npos != pos || std::string::npos != lastPos) - { - // Found a token, add it to the vector. - tokens.push_back(str.substr(lastPos, pos - lastPos)); - // Skip delimiters. 
Note the "not_of" - lastPos = str.find_first_not_of(delimiters, pos); - // Find next "non-delimiter" - pos = str.find_first_of(delimiters, lastPos); - } - - return tokens; -} - -const std::string Trim(const std::string& str, const std::string dropChars = " \t\n\r") -{ - std::string res = str; - res.erase(str.find_last_not_of(dropChars)+1); - return res.erase(0, res.find_first_not_of(dropChars)); -} - -string ParseXmlTagAttribute(const string& tag,const string& attributeName){ - /*TODO deal with unescaping \"*/ - string tagOpen = attributeName + "=\""; - size_t contentsStart = tag.find(tagOpen); - if (contentsStart == string::npos) return ""; - contentsStart += tagOpen.size(); - size_t contentsEnd = tag.find_first_of('"',contentsStart+1); - if (contentsEnd == string::npos) { - cerr << "Malformed XML attribute: "<< tag; - return ""; - } - size_t possibleEnd; - while (tag.at(contentsEnd-1) == '\\' && (possibleEnd = tag.find_first_of('"',contentsEnd+1)) != string::npos) { - contentsEnd = possibleEnd; - } - return tag.substr(contentsStart,contentsEnd-contentsStart); -} - -/** - * Remove "<" and ">" from XML tag - * - * \param str xml token to be stripped - */ -string TrimXml(const string& str) -{ - // too short to be xml token -> do nothing - if (str.size() < 2) return str; - - // strip first and last character - if (str[0] == '<' && str[str.size() - 1] == '>') - { - return str.substr(1, str.size() - 2); - } - // not an xml token -> do nothing - else { return str; } -} - -/** - * Check if the token is an XML tag, i.e. starts with "<" - * - * \param tag token to be checked - */ -bool isXmlTag(const string& tag) -{ - return tag[0] == '<'; -} - -/** - * Split up the input character string into tokens made up of - * either XML tags or text. - * example: this is a test . - * => (this ), (), ( is a ), (), ( test .) 
- * - * \param str input string - */ -inline vector TokenizeXml(const string& str) -{ - string lbrack = "<"; - string rbrack = ">"; - vector tokens; // vector of tokens to be returned - string::size_type cpos = 0; // current position in string - string::size_type lpos = 0; // left start of xml tag - string::size_type rpos = 0; // right end of xml tag - - // walk thorugh the string (loop vver cpos) - while (cpos != str.size()) - { - // find the next opening "<" of an xml tag - lpos = str.find_first_of(lbrack, cpos); - if (lpos != string::npos) - { - // find the end of the xml tag - rpos = str.find_first_of(rbrack, lpos); - // sanity check: there has to be closing ">" - if (rpos == string::npos) - { - cerr << "ERROR: malformed XML: " << str << endl; - return tokens; - } - } - else // no more tags found - { - // add the rest as token - tokens.push_back(str.substr(cpos)); - break; - } - - // add stuff before xml tag as token, if there is any - if (lpos - cpos > 0) - tokens.push_back(str.substr(cpos, lpos - cpos)); - - // add xml tag as token - tokens.push_back(str.substr(lpos, rpos-lpos+1)); - cpos = rpos + 1; - } - return tokens; -} - -/** - * Process a sentence with xml annotation - * Xml tags may specifiy additional/replacing translation options - * and reordering constraints - * - * \param line in: sentence, out: sentence without the xml - * \param res vector with translation options specified by xml - * \param reorderingConstraint reordering constraint zones specified by xml - * \param walls reordering constraint walls specified by xml - */ -/*TODO: we'd only have to return a vector of XML options if we dropped linking. 2-d vector - is so we can link things up afterwards. We can't create TranslationOptions as we - parse because we don't have the completed source parsed until after this function - removes all the markup from it (CreateFromString in Sentence::Read). 
-*/ -bool ProcessAndStripXMLTags(string &line, SyntaxTree &tree, set< string > &labelCollection, map< string, int > &topLabelCollection ) { - //parse XML markup in translation line - - // no xml tag? we're done. - if (line.find_first_of('<') == string::npos) { return true; } - - // break up input into a vector of xml tags and text - // example: (this), (), (is a), (), (test .) - vector xmlTokens = TokenizeXml(line); - - // we need to store opened tags, until they are closed - // tags are stored as tripled (tagname, startpos, contents) - typedef pair< string, pair< size_t, string > > OpenedTag; - vector< OpenedTag > tagStack; // stack that contains active opened tags - - string cleanLine; // return string (text without xml) - size_t wordPos = 0; // position in sentence (in terms of number of words) - bool isLinked = false; - - // loop through the tokens - for (size_t xmlTokenPos = 0 ; xmlTokenPos < xmlTokens.size() ; xmlTokenPos++) - { - // not a xml tag, but regular text (may contain many words) - if(!isXmlTag(xmlTokens[xmlTokenPos])) - { - // add a space at boundary, if necessary - if (cleanLine.size()>0 && - cleanLine[cleanLine.size() - 1] != ' ' && - xmlTokens[xmlTokenPos][0] != ' ') - { - cleanLine += " "; - } - cleanLine += xmlTokens[xmlTokenPos]; // add to output - wordPos = Tokenize(cleanLine).size(); // count all the words - } - - // process xml tag - else - { - // *** get essential information about tag *** - - // strip extra boundary spaces and "<" and ">" - string tag = Trim(TrimXml(xmlTokens[xmlTokenPos])); - // cerr << "XML TAG IS: " << tag << std::endl; - - if (tag.size() == 0) - { - cerr << "ERROR: empty tag name: " << line << endl; - return false; - } - - // check if unary (e.g., "") - bool isUnary = ( tag[tag.size() - 1] == '/' ); - - // check if opening tag (e.g. 
"", not "")g - bool isClosed = ( tag[0] == '/' ); - bool isOpen = !isClosed; - - if (isClosed && isUnary) - { - cerr << "ERROR: can't have both closed and unary tag <" << tag << ">: " << line << endl; - return false; - } - - if (isClosed) - tag = tag.substr(1); // remove "/" at the beginning - if (isUnary) - tag = tag.substr(0,tag.size()-1); // remove "/" at the end - - // find the tag name and contents - string::size_type endOfName = tag.find_first_of(' '); - string tagName = tag; - string tagContent = ""; - if (endOfName != string::npos) { - tagName = tag.substr(0,endOfName); - tagContent = tag.substr(endOfName+1); - } - - // *** process new tag *** - - if (isOpen || isUnary) - { - // put the tag on the tag stack - OpenedTag openedTag = make_pair( tagName, make_pair( wordPos, tagContent ) ); - tagStack.push_back( openedTag ); - // cerr << "XML TAG " << tagName << " (" << tagContent << ") added to stack, now size " << tagStack.size() << endl; - } - - // *** process completed tag *** - - if (isClosed || isUnary) - { - // pop last opened tag from stack; - if (tagStack.size() == 0) - { - cerr << "ERROR: tag " << tagName << " closed, but not opened" << ":" << line << endl; - return false; - } - OpenedTag openedTag = tagStack.back(); - tagStack.pop_back(); - - // tag names have to match - if (openedTag.first != tagName) - { - cerr << "ERROR: tag " << openedTag.first << " closed by tag " << tagName << ": " << line << endl; - return false; - } - - // assemble remaining information about tag - size_t startPos = openedTag.second.first; - string tagContent = openedTag.second.second; - size_t endPos = wordPos; - - // span attribute overwrites position - string span = ParseXmlTagAttribute(tagContent,"span"); - if (! 
span.empty()) - { - vector ij = Tokenize(span, "-"); - if (ij.size() != 1 && ij.size() != 2) { - cerr << "ERROR: span attribute must be of the form \"i-j\" or \"i\": " << line << endl; - return false; - } - startPos = atoi(ij[0].c_str()); - if (ij.size() == 1) endPos = startPos + 1; - else endPos = atoi(ij[1].c_str()) + 1; - } - - // cerr << "XML TAG " << tagName << " (" << tagContent << ") spanning " << startPos << " to " << (endPos-1) << " complete, commence processing" << endl; - - if (startPos >= endPos) - { - cerr << "ERROR: tag " << tagName << " must span at least one word (" << startPos << "-" << endPos << "): " << line << endl; - return false; - } - - string label = ParseXmlTagAttribute(tagContent,"label"); - labelCollection.insert( label ); - - // report what we have processed so far - if (0) { - cerr << "XML TAG NAME IS: '" << tagName << "'" << endl; - cerr << "XML TAG LABEL IS: '" << label << "'" << endl; - cerr << "XML SPAN IS: " << startPos << "-" << (endPos-1) << endl; - } - tree.AddNode( startPos, endPos-1, label ); - } - } - } - // we are done. 
check if there are tags that are still open - if (tagStack.size() > 0) - { - cerr << "ERROR: some opened tags were never closed: " << line << endl; - return false; - } - - // collect top labels - const SyntaxNodes &topNodes = tree.GetNodes( 0, wordPos-1 ); - for( SyntaxNodes::const_iterator node = topNodes.begin(); node != topNodes.end(); node++ ) - { - const SyntaxNode *n = *node; - const string &label = n->GetLabel(); - if (topLabelCollection.find( label ) == topLabelCollection.end()) - topLabelCollection[ label ] = 0; - topLabelCollection[ label ]++; - } - - // return de-xml'ed sentence in line - line = cleanLine; - return true; -} diff --git a/contrib/other-builds/extract-mixed-syntax/XmlTree.h b/contrib/other-builds/extract-mixed-syntax/XmlTree.h deleted file mode 100644 index cd54b8f17..000000000 --- a/contrib/other-builds/extract-mixed-syntax/XmlTree.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -// $Id: XmlOption.cpp 1960 2008-12-15 12:52:38Z phkoehn $ -// vim:tabstop=2 - -/*********************************************************************** - Moses - factored phrase-based language decoder - Copyright (C) 2006 University of Edinburgh - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. 
- - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - ***********************************************************************/ - -#include -#include -#include -#include -#include "SyntaxTree.h" - -std::string ParseXmlTagAttribute(const std::string& tag,const std::string& attributeName); -std::string TrimXml(const std::string& str); -bool isXmlTag(const std::string& tag); -inline std::vector TokenizeXml(const std::string& str); -bool ProcessAndStripXMLTags(std::string &line, SyntaxTree &tree, std::set< std::string > &labelCollection, std::map< std::string, int > &topLabelCollection ); diff --git a/contrib/other-builds/extract-mixed-syntax/extract.cpp b/contrib/other-builds/extract-mixed-syntax/extract.cpp deleted file mode 100644 index 334a3e124..000000000 --- a/contrib/other-builds/extract-mixed-syntax/extract.cpp +++ /dev/null @@ -1,310 +0,0 @@ -// $Id: extract.cpp 2828 2010-02-01 16:07:58Z hieuhoang1972 $ -// vim:tabstop=2 - -/*********************************************************************** - Moses - factored phrase-based language decoder - Copyright (C) 2009 University of Edinburgh - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. 
- - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - ***********************************************************************/ - -#include -#include -#include -#include -#include -#include -#include -#include "extract.h" -#include "InputFileStream.h" -#include "OutputFileStream.h" -#include "Lattice.h" - -#ifdef WIN32 -// Include Visual Leak Detector -#include -#endif - -using namespace std; - -void writeGlueGrammar(const string &, Global &options, set< string > &targetLabelCollection, map< string, int > &targetTopLabelCollection); - -int main(int argc, char* argv[]) -{ - cerr << "Extract v2.0, written by Philipp Koehn\n" - << "rule extraction from an aligned parallel corpus\n"; - //time_t starttime = time(NULL); - - Global *global = new Global(); - g_global = global; - int sentenceOffset = 0; - - if (argc < 5) { - cerr << "syntax: extract-mixed-syntax corpus.target corpus.source corpus.align extract " - << " [ --Hierarchical | --Orientation" - << " | --GlueGrammar FILE | --UnknownWordLabel FILE" - << " | --OnlyDirect" - - << " | --MinHoleSpanSourceDefault[" << global->minHoleSpanSourceDefault << "]" - << " | --MaxHoleSpanSourceDefault[" << global->maxHoleSpanSourceDefault << "]" - << " | --MinHoleSpanSourceSyntax[" << global->minHoleSpanSourceSyntax << "]" - << " | --MaxHoleSpanSourceSyntax[" << global->maxHoleSpanSourceSyntax << "]" - - << " | --MaxSymbols[" << global->maxSymbols<< "]" - << " | --MaxNonTerm[" << global->maxNonTerm << "]" - << " | --SourceSyntax | --TargetSyntax" - << " | --UppermostOnly[" << g_global->uppermostOnly << "]" - << endl; - exit(1); - } - char* &fileNameT = argv[1]; - char* &fileNameS = argv[2]; - char* &fileNameA = argv[3]; - string fileNameGlueGrammar; - string fileNameUnknownWordLabel; - string fileNameExtract = string(argv[4]); - - int optionInd = 5; - - for(int 
i=optionInd;iminHoleSpanSourceDefault = atoi(argv[++i]); - if (global->minHoleSpanSourceDefault < 1) { - cerr << "extract error: --minHoleSourceDefault should be at least 1" << endl; - exit(1); - } - } - else if (strcmp(argv[i],"--MaxHoleSpanSourceDefault") == 0) { - global->maxHoleSpanSourceDefault = atoi(argv[++i]); - if (global->maxHoleSpanSourceDefault < 1) { - cerr << "extract error: --maxHoleSourceDefault should be at least 1" << endl; - exit(1); - } - } - else if (strcmp(argv[i],"--MinHoleSpanSourceSyntax") == 0) { - global->minHoleSpanSourceSyntax = atoi(argv[++i]); - if (global->minHoleSpanSourceSyntax < 1) { - cerr << "extract error: --minHoleSourceSyntax should be at least 1" << endl; - exit(1); - } - } - else if (strcmp(argv[i],"--UppermostOnly") == 0) { - global->uppermostOnly = atoi(argv[++i]); - } - else if (strcmp(argv[i],"--MaxHoleSpanSourceSyntax") == 0) { - global->maxHoleSpanSourceSyntax = atoi(argv[++i]); - if (global->maxHoleSpanSourceSyntax < 1) { - cerr << "extract error: --maxHoleSourceSyntax should be at least 1" << endl; - exit(1); - } - } - - // maximum number of words in hierarchical phrase - else if (strcmp(argv[i],"--maxSymbols") == 0) { - global->maxSymbols = atoi(argv[++i]); - if (global->maxSymbols < 1) { - cerr << "extract error: --maxSymbols should be at least 1" << endl; - exit(1); - } - } - // maximum number of non-terminals - else if (strcmp(argv[i],"--MaxNonTerm") == 0) { - global->maxNonTerm = atoi(argv[++i]); - if (global->maxNonTerm < 1) { - cerr << "extract error: --MaxNonTerm should be at least 1" << endl; - exit(1); - } - } - // allow consecutive non-terminals (X Y | X Y) - else if (strcmp(argv[i],"--TargetSyntax") == 0) { - global->targetSyntax = true; - } - else if (strcmp(argv[i],"--SourceSyntax") == 0) { - global->sourceSyntax = true; - } - // do not create many part00xx files! 
- else if (strcmp(argv[i],"--NoFileLimit") == 0) { - // now default - } - else if (strcmp(argv[i],"--GlueGrammar") == 0) { - global->glueGrammarFlag = true; - if (++i >= argc) - { - cerr << "ERROR: Option --GlueGrammar requires a file name" << endl; - exit(0); - } - fileNameGlueGrammar = string(argv[i]); - cerr << "creating glue grammar in '" << fileNameGlueGrammar << "'" << endl; - } - else if (strcmp(argv[i],"--UnknownWordLabel") == 0) { - global->unknownWordLabelFlag = true; - if (++i >= argc) - { - cerr << "ERROR: Option --UnknownWordLabel requires a file name" << endl; - exit(0); - } - fileNameUnknownWordLabel = string(argv[i]); - cerr << "creating unknown word labels in '" << fileNameUnknownWordLabel << "'" << endl; - } - // TODO: this should be a useful option - //else if (strcmp(argv[i],"--ZipFiles") == 0) { - // zipFiles = true; - //} - // if an source phrase is paired with two target phrases, then count(t|s) = 0.5 - else if (strcmp(argv[i],"--Mixed") == 0) { - global->mixed = true; - } - else if (strcmp(argv[i],"--AllowDefaultNonTermEdge") == 0) { - global->allowDefaultNonTermEdge = atoi(argv[++i]); - } - else if (strcmp(argv[i], "--GZOutput") == 0) { - global->gzOutput = true; - } - else if (strcmp(argv[i],"--MaxSpan") == 0) { - // ignore - ++i; - } - else if (strcmp(argv[i],"--SentenceOffset") == 0) { - if (i+1 >= argc || argv[i+1][0] < '0' || argv[i+1][0] > '9') { - cerr << "extract: syntax error, used switch --SentenceOffset without a number" << endl; - exit(1); - } - sentenceOffset = atoi(argv[++i]); - } - else { - cerr << "extract: syntax error, unknown option '" << string(argv[i]) << "'\n"; - exit(1); - } - } - - - // open input files - Moses::InputFileStream tFile(fileNameT); - Moses::InputFileStream sFile(fileNameS); - Moses::InputFileStream aFile(fileNameA); - - // open output files - string fileNameExtractInv = fileNameExtract + ".inv"; - if (global->gzOutput) { - fileNameExtract += ".gz"; - fileNameExtractInv += ".gz"; - } - - 
Moses::OutputFileStream extractFile; - Moses::OutputFileStream extractFileInv; - extractFile.Open(fileNameExtract.c_str()); - extractFileInv.Open(fileNameExtractInv.c_str()); - - - // loop through all sentence pairs - int i = sentenceOffset; - while(true) { - i++; - - if (i % 1000 == 0) { - cerr << i << " " << flush; - } - - string targetString; - string sourceString; - string alignmentString; - - bool ok = getline(tFile, targetString); - if (!ok) - break; - getline(sFile, sourceString); - getline(aFile, alignmentString); - - //cerr << endl << targetString << endl << sourceString << endl << alignmentString << endl; - - //time_t currTime = time(NULL); - //cerr << "A " << (currTime - starttime) << endl; - - SentenceAlignment sentencePair; - if (sentencePair.Create( targetString, sourceString, alignmentString, i, *global )) - { - //cerr << sentence.sourceTree << endl; - //cerr << sentence.targetTree << endl; - - sentencePair.FindTunnels(*g_global); - //cerr << "C " << (time(NULL) - starttime) << endl; - //cerr << sentencePair << endl; - - sentencePair.CreateLattice(*g_global); - //cerr << "D " << (time(NULL) - starttime) << endl; - //cerr << sentencePair << endl; - - sentencePair.CreateRules(*g_global); - //cerr << "E " << (time(NULL) - starttime) << endl; - - //cerr << sentence.lattice->GetRules().GetSize() << endl; - sentencePair.GetLattice().GetRules().Output(extractFile); - sentencePair.GetLattice().GetRules().OutputInv(extractFileInv); - } - } - - tFile.Close(); - sFile.Close(); - aFile.Close(); - - extractFile.Close(); - extractFileInv.Close(); - - if (global->glueGrammarFlag) { - writeGlueGrammar(fileNameGlueGrammar, *global, targetLabelCollection, targetTopLabelCollection); - } - - delete global; -} - - -void writeGlueGrammar( const string & fileName, Global &options, set< string > &targetLabelCollection, map< string, int > &targetTopLabelCollection ) -{ - ofstream grammarFile; - grammarFile.open(fileName.c_str()); - if (!options.targetSyntax) { - grammarFile 
<< " [X] ||| [S] ||| 1 ||| ||| 0" << endl - << "[X][S] [X] ||| [X][S] [S] ||| 1 ||| 0-0 ||| 0" << endl - << "[X][S] [X][X] [X] ||| [X][S] [X][X] [S] ||| 2.718 ||| 0-0 1-1 ||| 0" << endl; - } else { - // chose a top label that is not already a label - string topLabel = "QQQQQQ"; - for( unsigned int i=1; i<=topLabel.length(); i++) { - if(targetLabelCollection.find( topLabel.substr(0,i) ) == targetLabelCollection.end() ) { - topLabel = topLabel.substr(0,i); - break; - } - } - // basic rules - grammarFile << " [X] ||| [" << topLabel << "] ||| 1 ||| " << endl - << "[X][" << topLabel << "] [X] ||| [X][" << topLabel << "] [" << topLabel << "] ||| 1 ||| 0-0 " << endl; - - // top rules - for( map::const_iterator i = targetTopLabelCollection.begin(); - i != targetTopLabelCollection.end(); i++ ) { - grammarFile << " [X][" << i->first << "] [X] ||| [X][" << i->first << "] [" << topLabel << "] ||| 1 ||| 1-1" << endl; - } - - // glue rules - for( set::const_iterator i = targetLabelCollection.begin(); - i != targetLabelCollection.end(); i++ ) { - grammarFile << "[X][" << topLabel << "] [X][" << *i << "] [X] ||| [X][" << topLabel << "] [X][" << *i << "] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1" << endl; - } - grammarFile << "[X][" << topLabel << "] [X][X] [X] ||| [X][" << topLabel << "] [X][X] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 " << endl; // glue rule for unknown word... 
- } - grammarFile.close(); -} - diff --git a/contrib/other-builds/extract-mixed-syntax/extract.h b/contrib/other-builds/extract-mixed-syntax/extract.h deleted file mode 100644 index ac831f2d9..000000000 --- a/contrib/other-builds/extract-mixed-syntax/extract.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include "SyntaxTree.h" -#include "XmlTree.h" -#include "Tunnel.h" -#include "TunnelCollection.h" -#include "SentenceAlignment.h" -#include "Global.h" - -std::vector tokenize( const char [] ); - -#define SAFE_GETLINE(_IS, _LINE, _SIZE, _DELIM) { \ - _IS.getline(_LINE, _SIZE, _DELIM); \ - if(_IS.fail() && !_IS.bad() && !_IS.eof()) _IS.clear(); \ - if (_IS.gcount() == _SIZE-1) { \ - cerr << "Line too long! Buffer overflow. Delete lines >=" \ - << _SIZE << " chars or raise LINE_MAX_LENGTH in phrase-extract/extract.cpp" \ - << endl; \ - exit(1); \ - } \ - } -#define LINE_MAX_LENGTH 1000000 - -const Global *g_global; - -std::set< std::string > targetLabelCollection, sourceLabelCollection; -std::map< std::string, int > targetTopLabelCollection, sourceTopLabelCollection; diff --git a/contrib/other-builds/extract-mixed-syntax/filter-by-source-word-count.perl b/contrib/other-builds/extract-mixed-syntax/filter-by-source-word-count.perl new file mode 100755 index 000000000..d0e482a02 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/filter-by-source-word-count.perl @@ -0,0 +1,27 @@ +#!/usr/bin/perl + +use strict; + +binmode(STDIN, ":utf8"); +binmode(STDOUT, ":utf8"); +binmode(STDERR, ":utf8"); + +my $maxNumWords = $ARGV[0]; + +while (my $line = ) { + chomp($line); + my @toks = split(/ /,$line); + + my $numSourceWords = 0; + my $tok = $toks[$numSourceWords]; + while ($tok ne "|||") { + ++$numSourceWords; + $tok = $toks[$numSourceWords]; + } + + if ($numSourceWords <= $maxNumWords) { + print "$line\n"; + } +} + + diff --git a/contrib/other-builds/extract-mixed-syntax/learnable/equal.perl 
b/contrib/other-builds/extract-mixed-syntax/learnable/equal.perl new file mode 100755 index 000000000..e43b48a84 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/learnable/equal.perl @@ -0,0 +1,33 @@ +#! /usr/bin/perl -w + +use strict; + +sub trim($); + +my $file1 = $ARGV[0]; +my $file2 = $ARGV[1]; + +open (FILE1, $file1); +open (FILE2, $file2); + +my $countEqual = 0; +while (my $line1 = ) { + my $line2 = ; + if (trim($line1) eq trim($line2)) { + ++$countEqual; + } +} + +print $countEqual ."\n"; + + +###################### +# Perl trim function to remove whitespace from the start and end of the string +sub trim($) { + my $string = shift; + $string =~ s/^\s+//; + $string =~ s/\s+$//; + return $string; +} + + diff --git a/contrib/other-builds/extract-mixed-syntax/learnable/get-by-line-number.perl b/contrib/other-builds/extract-mixed-syntax/learnable/get-by-line-number.perl new file mode 100755 index 000000000..f9ec9e39b --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/learnable/get-by-line-number.perl @@ -0,0 +1,29 @@ +#! /usr/bin/perl -w + +use strict; + +binmode(STDIN, ":utf8"); +binmode(STDOUT, ":utf8"); +binmode(STDERR, ":utf8"); + +my $fileLineNum = $ARGV[0]; +open (FILE_LINE_NUM, $fileLineNum); + +my $nextLineNum = ; + +my $lineNum = 1; +while (my $line = ) { + if (defined($nextLineNum) && $lineNum == $nextLineNum) { + # matches. output line + chomp($line); + print "$line\n"; + + # next line number + $nextLineNum = ; + } + + ++$lineNum; +} + + + diff --git a/contrib/other-builds/extract-mixed-syntax/learnable/learnable.perl b/contrib/other-builds/extract-mixed-syntax/learnable/learnable.perl new file mode 100755 index 000000000..6edcff3f9 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/learnable/learnable.perl @@ -0,0 +1,108 @@ +#! 
/usr/bin/perl -w + +use strict; + +my $iniPath = $ARGV[0]; +my $isHiero = $ARGV[1]; +my $decoderExec = $ARGV[2]; +my $extractExec = $ARGV[3]; +my $tmpName = $ARGV[4]; + +my $WORK_DIR = `pwd`; +chomp($WORK_DIR); + +my $MOSES_DIR = "~/workspace/github/mosesdecoder.hieu"; + +$decoderExec = "$MOSES_DIR/bin/$decoderExec"; +$extractExec = "$MOSES_DIR/bin/$extractExec"; + +my $SPLIT_EXEC = `gsplit --help 2>/dev/null`; +if($SPLIT_EXEC) { + $SPLIT_EXEC = 'gsplit'; +} +else { + $SPLIT_EXEC = 'split'; +} + +my $SORT_EXEC = `gsort --help 2>/dev/null`; +if($SORT_EXEC) { + $SORT_EXEC = 'gsort'; +} +else { + $SORT_EXEC = 'sort'; +} + + +my $hieroFlag = ""; +if ($isHiero == 1) { + $hieroFlag = "--Hierarchical"; +} + +print STDERR "WORK_DIR=$WORK_DIR \n"; + +my $cmd; + +open (SOURCE, "source"); +open (TARGET, "target"); +open (ALIGNMENT, "alignment"); + +my $lineNum = 0; +my ($source, $target, $alignment); +while ($source = ) { + chomp($source); + $target = ; chomp($target); + $alignment = ; chomp($alignment); + + #print STDERR "$source ||| $target ||| $alignment \n"; + + # write out 1 line + my $tmpDir = "$WORK_DIR/$tmpName/work$lineNum"; + `mkdir -p $tmpDir`; + + open (SOURCE1, ">$tmpDir/source"); + open (TARGET1, ">$tmpDir/target"); + open (ALIGNMENT1, ">$tmpDir/alignment"); + + print SOURCE1 "$source\n"; + print TARGET1 "$target\n"; + print ALIGNMENT1 "$alignment\n"; + + close (SOURCE1); + close (TARGET1); + close (ALIGNMENT1); + + # train + if ($isHiero == 1) { + $cmd = "$extractExec $tmpDir/target $tmpDir/source $tmpDir/alignment $tmpDir/extract --GZOutput"; + } + else { + # pb + $cmd = "$extractExec $tmpDir/target $tmpDir/source $tmpDir/alignment $tmpDir/extract 7 --GZOutput"; + } + $cmd = "$MOSES_DIR/scripts/generic/extract-parallel.perl 1 $SPLIT_EXEC $SORT_EXEC $cmd"; + print STDERR "Executing: $cmd\n"; + `$cmd`; + + $cmd = "$MOSES_DIR/scripts/generic/score-parallel.perl 1 $SORT_EXEC $MOSES_DIR/bin/score $tmpDir/extract.sorted.gz /dev/null $tmpDir/pt.half.gz $hieroFlag 
--NoLex 1"; + `$cmd`; + + $cmd = "$MOSES_DIR/scripts/generic/score-parallel.perl 1 $SORT_EXEC $MOSES_DIR/bin/score $tmpDir/extract.inv.sorted.gz /dev/null $tmpDir/pt.half.inv.gz --Inverse $hieroFlag --NoLex 1"; + `$cmd`; + + $cmd = "$MOSES_DIR/bin/consolidate $tmpDir/pt.half.gz $tmpDir/pt.half.inv.gz $tmpDir/pt $hieroFlag --OnlyDirect"; + `$cmd`; + + # decode + $cmd = "$decoderExec -f $iniPath -feature-overwrite \"TranslationModel0 path=$tmpDir/pt\" -i $tmpDir/source -feature-add \"ConstrainedDecoding path=$tmpDir/target\""; + print STDERR "Executing: $cmd\n"; + `$cmd`; + +# `rm -rf $tmpDir`; + + ++$lineNum; +} + +close(SOURCE); +close(TARGET); +close(ALIGNMENT); + diff --git a/contrib/other-builds/extract-mixed-syntax/learnable/num-deriv.perl b/contrib/other-builds/extract-mixed-syntax/learnable/num-deriv.perl new file mode 100755 index 000000000..5d66d5505 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/learnable/num-deriv.perl @@ -0,0 +1,151 @@ +#! /usr/bin/perl -w + +use strict; + +sub Write1Line; +sub WriteCorpus1Holdout; + +my $iniPath = $ARGV[0]; +my $isHiero = $ARGV[1]; +my $decoderExec = $ARGV[2]; +my $extractExec = $ARGV[3]; +my $tmpName = $ARGV[4]; +my $startLine = $ARGV[5]; +my $endLine = $ARGV[6]; + +print STDERR "iniPath=$iniPath \n isHiero=$isHiero \n decoderExec=$decoderExec \n extractExec=$extractExec \n"; + +my $WORK_DIR = `pwd`; +chomp($WORK_DIR); + +my $MOSES_DIR = "~/workspace/github/mosesdecoder.hieu.gna"; + +$decoderExec = "$MOSES_DIR/bin/$decoderExec"; +$extractExec = "$MOSES_DIR/bin/$extractExec"; + +my $SPLIT_EXEC = `gsplit --help 2>/dev/null`; +if($SPLIT_EXEC) { + $SPLIT_EXEC = 'gsplit'; +} +else { + $SPLIT_EXEC = 'split'; +} + +my $SORT_EXEC = `gsort --help 2>/dev/null`; +if($SORT_EXEC) { + $SORT_EXEC = 'gsort'; +} +else { + $SORT_EXEC = 'sort'; +} + + +my $hieroFlag = ""; +if ($isHiero == 1) { + $hieroFlag = "--Hierarchical"; +} + +print STDERR "WORK_DIR=$WORK_DIR \n"; + +my $cmd; + +open (SOURCE, "source"); +open (TARGET, 
"target"); +open (ALIGNMENT, "alignment"); + +my $numLines = `cat source | wc -l`; + +for (my $lineNum = 0; $lineNum < $numLines; ++$lineNum) { + my $source = ; chomp($source); + my $target = ; chomp($target); + my $alignment = ; chomp($alignment); + + if ($lineNum < $startLine || $lineNum >= $endLine) { + next; + } + + #print STDERR "$source ||| $target ||| $alignment \n"; + # write out 1 line + my $tmpDir = "$WORK_DIR/$tmpName/work$lineNum"; + `mkdir -p $tmpDir`; + + Write1Line($source, $tmpDir, "source.1"); + Write1Line($target, $tmpDir, "target.1"); + Write1Line($alignment, $tmpDir, "alignment.1"); + + WriteCorpus1Holdout($lineNum, "source", $tmpDir, "source.corpus"); + WriteCorpus1Holdout($lineNum, "target", $tmpDir, "target.corpus"); + WriteCorpus1Holdout($lineNum, "alignment", $tmpDir, "alignment.corpus"); + + # train + if ($isHiero == 1) { + $cmd = "$extractExec $tmpDir/target.corpus $tmpDir/source.corpus $tmpDir/alignment.corpus $tmpDir/extract --GZOutput"; + } + else { + # pb + $cmd = "$extractExec $tmpDir/target.corpus $tmpDir/source.corpus $tmpDir/alignment.corpus $tmpDir/extract 7 --GZOutput"; + } + $cmd = "$MOSES_DIR/scripts/generic/extract-parallel.perl 1 $SPLIT_EXEC $SORT_EXEC $cmd"; + print STDERR "Executing: $cmd\n"; + `$cmd`; + + $cmd = "$MOSES_DIR/scripts/generic/score-parallel.perl 1 $SORT_EXEC $MOSES_DIR/bin/score $tmpDir/extract.sorted.gz /dev/null $tmpDir/pt.half.gz $hieroFlag --NoLex 1"; + `$cmd`; + + $cmd = "$MOSES_DIR/scripts/generic/score-parallel.perl 1 $SORT_EXEC $MOSES_DIR/bin/score $tmpDir/extract.inv.sorted.gz /dev/null $tmpDir/pt.half.inv.gz --Inverse $hieroFlag --NoLex 1"; + `$cmd`; + + $cmd = "$MOSES_DIR/bin/consolidate $tmpDir/pt.half.gz $tmpDir/pt.half.inv.gz $tmpDir/pt $hieroFlag --OnlyDirect"; + `$cmd`; + + # decode + $cmd = "$decoderExec -f $iniPath -feature-overwrite \"TranslationModel0 path=$tmpDir/pt\" -i $tmpDir/source.1 -n-best-list $tmpDir/nbest 10000 distinct -v 2"; + print STDERR "Executing: $cmd\n"; + `$cmd`; + + # 
count the number of translation in nbest list + $cmd = "wc -l $tmpDir/nbest >> out"; + `$cmd`; + + `rm -rf $tmpDir`; +} + +close(SOURCE); +close(TARGET); +close(ALIGNMENT); + + +###################### +sub Write1Line +{ + my ($line, $tmpDir, $fileName) = @_; + + open (HANDLE, ">$tmpDir/$fileName"); + print HANDLE "$line\n"; + close (HANDLE); +} + +sub WriteCorpus1Holdout +{ + my ($holdoutLineNum, $inFilePath, $tmpDir, $outFileName) = @_; + + open (INFILE, "$inFilePath"); + open (OUTFILE, ">$tmpDir/$outFileName"); + + my $lineNum = 0; + while (my $line = ) { + chomp($line); + + if ($lineNum != $holdoutLineNum) { + print OUTFILE "$line\n"; + } + + ++$lineNum; + } + + close (OUTFILE); + close(INFILE); + +} + + diff --git a/contrib/other-builds/extract-mixed-syntax/learnable/reachable.perl b/contrib/other-builds/extract-mixed-syntax/learnable/reachable.perl new file mode 100755 index 000000000..14432f5a7 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/learnable/reachable.perl @@ -0,0 +1,147 @@ +#! 
/usr/bin/perl -w + +use strict; + +sub Write1Line; +sub WriteCorpus1Holdout; + +my $iniPath = $ARGV[0]; +my $isHiero = $ARGV[1]; +my $decoderExec = $ARGV[2]; +my $extractExec = $ARGV[3]; +my $tmpName = $ARGV[4]; +my $startLine = $ARGV[5]; +my $endLine = $ARGV[6]; + +print STDERR "iniPath=$iniPath \n isHiero=$isHiero \n decoderExec=$decoderExec \n extractExec=$extractExec \n"; + +my $WORK_DIR = `pwd`; +chomp($WORK_DIR); + +my $MOSES_DIR = "~/workspace/github/mosesdecoder.hieu.gna"; + +$decoderExec = "$MOSES_DIR/bin/$decoderExec"; +$extractExec = "$MOSES_DIR/bin/$extractExec"; + +my $SPLIT_EXEC = `gsplit --help 2>/dev/null`; +if($SPLIT_EXEC) { + $SPLIT_EXEC = 'gsplit'; +} +else { + $SPLIT_EXEC = 'split'; +} + +my $SORT_EXEC = `gsort --help 2>/dev/null`; +if($SORT_EXEC) { + $SORT_EXEC = 'gsort'; +} +else { + $SORT_EXEC = 'sort'; +} + + +my $hieroFlag = ""; +if ($isHiero == 1) { + $hieroFlag = "--Hierarchical"; +} + +print STDERR "WORK_DIR=$WORK_DIR \n"; + +my $cmd; + +open (SOURCE, "source"); +open (TARGET, "target"); +open (ALIGNMENT, "alignment"); + +my $numLines = `cat source | wc -l`; + +for (my $lineNum = 0; $lineNum < $numLines; ++$lineNum) { + my $source = ; chomp($source); + my $target = ; chomp($target); + my $alignment = ; chomp($alignment); + + if ($lineNum < $startLine || $lineNum >= $endLine) { + next; + } + + #print STDERR "$source ||| $target ||| $alignment \n"; + # write out 1 line + my $tmpDir = "$WORK_DIR/$tmpName/work$lineNum"; + `mkdir -p $tmpDir`; + + Write1Line($source, $tmpDir, "source.1"); + Write1Line($target, $tmpDir, "target.1"); + Write1Line($alignment, $tmpDir, "alignment.1"); + + WriteCorpus1Holdout($lineNum, "source", $tmpDir, "source.corpus"); + WriteCorpus1Holdout($lineNum, "target", $tmpDir, "target.corpus"); + WriteCorpus1Holdout($lineNum, "alignment", $tmpDir, "alignment.corpus"); + + # train + if ($isHiero == 1) { + $cmd = "$extractExec $tmpDir/target.corpus $tmpDir/source.corpus $tmpDir/alignment.corpus $tmpDir/extract 
--GZOutput"; + } + else { + # pb + $cmd = "$extractExec $tmpDir/target.corpus $tmpDir/source.corpus $tmpDir/alignment.corpus $tmpDir/extract 7 --GZOutput"; + } + $cmd = "$MOSES_DIR/scripts/generic/extract-parallel.perl 1 $SPLIT_EXEC $SORT_EXEC $cmd"; + print STDERR "Executing: $cmd\n"; + `$cmd`; + + $cmd = "$MOSES_DIR/scripts/generic/score-parallel.perl 1 $SORT_EXEC $MOSES_DIR/bin/score $tmpDir/extract.sorted.gz /dev/null $tmpDir/pt.half.gz $hieroFlag --NoLex 1"; + `$cmd`; + + $cmd = "$MOSES_DIR/scripts/generic/score-parallel.perl 1 $SORT_EXEC $MOSES_DIR/bin/score $tmpDir/extract.inv.sorted.gz /dev/null $tmpDir/pt.half.inv.gz --Inverse $hieroFlag --NoLex 1"; + `$cmd`; + + $cmd = "$MOSES_DIR/bin/consolidate $tmpDir/pt.half.gz $tmpDir/pt.half.inv.gz $tmpDir/pt $hieroFlag --OnlyDirect"; + `$cmd`; + + # decode + $cmd = "$decoderExec -f $iniPath -feature-overwrite \"TranslationModel0 path=$tmpDir/pt\" -i $tmpDir/source.1 -feature-add \"ConstrainedDecoding path=$tmpDir/target.1\" -v 2"; + print STDERR "Executing: $cmd\n"; + `$cmd`; + + `rm -rf $tmpDir`; +} + +close(SOURCE); +close(TARGET); +close(ALIGNMENT); + + +###################### +sub Write1Line +{ + my ($line, $tmpDir, $fileName) = @_; + + open (HANDLE, ">$tmpDir/$fileName"); + print HANDLE "$line\n"; + close (HANDLE); +} + +sub WriteCorpus1Holdout +{ + my ($holdoutLineNum, $inFilePath, $tmpDir, $outFileName) = @_; + + open (INFILE, "$inFilePath"); + open (OUTFILE, ">$tmpDir/$outFileName"); + + my $lineNum = 0; + while (my $line = ) { + chomp($line); + + if ($lineNum != $holdoutLineNum) { + print OUTFILE "$line\n"; + } + + ++$lineNum; + } + + close (OUTFILE); + close(INFILE); + +} + + diff --git a/contrib/other-builds/extract-mixed-syntax/learnable/run-parallel.perl b/contrib/other-builds/extract-mixed-syntax/learnable/run-parallel.perl new file mode 100755 index 000000000..fa271f9ad --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/learnable/run-parallel.perl @@ -0,0 +1,17 @@ +#! 
/usr/bin/perl -w + +my $iniPath = $ARGV[0]; + +my $SPLIT_LINES = 200; +my $lineCount = `cat source | wc -l`; +print STDERR "lineCount=$lineCount \n"; + +for (my $startLine = 0; $startLine < $lineCount; $startLine += $SPLIT_LINES) { + my $endLine = $startLine + $SPLIT_LINES; + + my $cmd = "../../scripts/reachable.perl $iniPath 1 moses_chart extract-rules tmp-reachable $startLine $endLine &>out.reachable.$startLine &"; + print STDERR "Executing: $cmd \n"; + system($cmd); + +} + diff --git a/contrib/other-builds/extract-mixed-syntax/pugiconfig.hpp b/contrib/other-builds/extract-mixed-syntax/pugiconfig.hpp new file mode 100644 index 000000000..c2196715c --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/pugiconfig.hpp @@ -0,0 +1,69 @@ +/** + * pugixml parser - version 1.2 + * -------------------------------------------------------- + * Copyright (C) 2006-2012, by Arseny Kapoulkine (arseny.kapoulkine@gmail.com) + * Report bugs and download new versions at http://pugixml.org/ + * + * This library is distributed under the MIT License. See notice at the end + * of this file. 
+ * + * This work is based on the pugxml parser, which is: + * Copyright (C) 2003, by Kristen Wegner (kristen@tima.net) + */ + +#ifndef HEADER_PUGICONFIG_HPP +#define HEADER_PUGICONFIG_HPP + +// Uncomment this to enable wchar_t mode +// #define PUGIXML_WCHAR_MODE + +// Uncomment this to disable XPath +// #define PUGIXML_NO_XPATH + +// Uncomment this to disable STL +// #define PUGIXML_NO_STL + +// Uncomment this to disable exceptions +// #define PUGIXML_NO_EXCEPTIONS + +// Set this to control attributes for public classes/functions, i.e.: +// #define PUGIXML_API __declspec(dllexport) // to export all public symbols from DLL +// #define PUGIXML_CLASS __declspec(dllimport) // to import all classes from DLL +// #define PUGIXML_FUNCTION __fastcall // to set calling conventions to all public functions to fastcall +// In absence of PUGIXML_CLASS/PUGIXML_FUNCTION definitions PUGIXML_API is used instead + +// Uncomment this to switch to header-only version +// #define PUGIXML_HEADER_ONLY +// #include "pugixml.cpp" + +// Tune these constants to adjust memory-related behavior +// #define PUGIXML_MEMORY_PAGE_SIZE 32768 +// #define PUGIXML_MEMORY_OUTPUT_STACK 10240 +// #define PUGIXML_MEMORY_XPATH_PAGE_SIZE 4096 + +#endif + +/** + * Copyright (c) 2006-2012 Arseny Kapoulkine + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ diff --git a/contrib/other-builds/extract-mixed-syntax/pugixml.cpp b/contrib/other-builds/extract-mixed-syntax/pugixml.cpp new file mode 100644 index 000000000..4035ab1cf --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/pugixml.cpp @@ -0,0 +1,10250 @@ +/** + * pugixml parser - version 1.2 + * -------------------------------------------------------- + * Copyright (C) 2006-2012, by Arseny Kapoulkine (arseny.kapoulkine@gmail.com) + * Report bugs and download new versions at http://pugixml.org/ + * + * This library is distributed under the MIT License. See notice at the end + * of this file. 
+ * + * This work is based on the pugxml parser, which is: + * Copyright (C) 2003, by Kristen Wegner (kristen@tima.net) + */ + +#ifndef SOURCE_PUGIXML_CPP +#define SOURCE_PUGIXML_CPP + +#include "pugixml.hpp" + +#include +#include +#include +#include +#include + +#ifndef PUGIXML_NO_XPATH +# include +# include +# ifdef PUGIXML_NO_EXCEPTIONS +# include +# endif +#endif + +#ifndef PUGIXML_NO_STL +# include +# include +# include +#endif + +// For placement new +#include + +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable: 4127) // conditional expression is constant +# pragma warning(disable: 4324) // structure was padded due to __declspec(align()) +# pragma warning(disable: 4611) // interaction between '_setjmp' and C++ object destruction is non-portable +# pragma warning(disable: 4702) // unreachable code +# pragma warning(disable: 4996) // this function or variable may be unsafe +# pragma warning(disable: 4793) // function compiled as native: presence of '_setjmp' makes a function unmanaged +#endif + +#ifdef __INTEL_COMPILER +# pragma warning(disable: 177) // function was declared but never referenced +# pragma warning(disable: 279) // controlling expression is constant +# pragma warning(disable: 1478 1786) // function was declared "deprecated" +# pragma warning(disable: 1684) // conversion from pointer to same-sized integral type +#endif + +#if defined(__BORLANDC__) && defined(PUGIXML_HEADER_ONLY) +# pragma warn -8080 // symbol is declared but never used; disabling this inside push/pop bracket does not make the warning go away +#endif + +#ifdef __BORLANDC__ +# pragma option push +# pragma warn -8008 // condition is always false +# pragma warn -8066 // unreachable code +#endif + +#ifdef __SNC__ +// Using diag_push/diag_pop does not disable the warnings inside templates due to a compiler bug +# pragma diag_suppress=178 // function was declared but never referenced +# pragma diag_suppress=237 // controlling expression is constant +#endif + +// Inlining 
controls +#if defined(_MSC_VER) && _MSC_VER >= 1300 +# define PUGI__NO_INLINE __declspec(noinline) +#elif defined(__GNUC__) +# define PUGI__NO_INLINE __attribute__((noinline)) +#else +# define PUGI__NO_INLINE +#endif + +// Simple static assertion +#define PUGI__STATIC_ASSERT(cond) { static const char condition_failed[(cond) ? 1 : -1] = {0}; (void)condition_failed[0]; } + +// Digital Mars C++ bug workaround for passing char loaded from memory via stack +#ifdef __DMC__ +# define PUGI__DMC_VOLATILE volatile +#else +# define PUGI__DMC_VOLATILE +#endif + +// Borland C++ bug workaround for not defining ::memcpy depending on header include order (can't always use std::memcpy because some compilers don't have it at all) +#if defined(__BORLANDC__) && !defined(__MEM_H_USING_LIST) +using std::memcpy; +using std::memmove; +#endif + +// In some environments MSVC is a compiler but the CRT lacks certain MSVC-specific features +#if defined(_MSC_VER) && !defined(__S3E__) +# define PUGI__MSVC_CRT_VERSION _MSC_VER +#endif + +#ifdef PUGIXML_HEADER_ONLY +# define PUGI__NS_BEGIN namespace pugi { namespace impl { +# define PUGI__NS_END } } +# define PUGI__FN inline +# define PUGI__FN_NO_INLINE inline +#else +# if defined(_MSC_VER) && _MSC_VER < 1300 // MSVC6 seems to have an amusing bug with anonymous namespaces inside namespaces +# define PUGI__NS_BEGIN namespace pugi { namespace impl { +# define PUGI__NS_END } } +# else +# define PUGI__NS_BEGIN namespace pugi { namespace impl { namespace { +# define PUGI__NS_END } } } +# endif +# define PUGI__FN +# define PUGI__FN_NO_INLINE PUGI__NO_INLINE +#endif + +// uintptr_t +#if !defined(_MSC_VER) || _MSC_VER >= 1600 +# include +#else +# ifndef _UINTPTR_T_DEFINED +// No native uintptr_t in MSVC6 and in some WinCE versions +typedef size_t uintptr_t; +#define _UINTPTR_T_DEFINED +# endif +PUGI__NS_BEGIN + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; +PUGI__NS_END +#endif + +// Memory 
allocation +PUGI__NS_BEGIN + PUGI__FN void* default_allocate(size_t size) + { + return malloc(size); + } + + PUGI__FN void default_deallocate(void* ptr) + { + free(ptr); + } + + template + struct xml_memory_management_function_storage + { + static allocation_function allocate; + static deallocation_function deallocate; + }; + + template allocation_function xml_memory_management_function_storage::allocate = default_allocate; + template deallocation_function xml_memory_management_function_storage::deallocate = default_deallocate; + + typedef xml_memory_management_function_storage xml_memory; +PUGI__NS_END + +// String utilities +PUGI__NS_BEGIN + // Get string length + PUGI__FN size_t strlength(const char_t* s) + { + assert(s); + + #ifdef PUGIXML_WCHAR_MODE + return wcslen(s); + #else + return strlen(s); + #endif + } + + // Compare two strings + PUGI__FN bool strequal(const char_t* src, const char_t* dst) + { + assert(src && dst); + + #ifdef PUGIXML_WCHAR_MODE + return wcscmp(src, dst) == 0; + #else + return strcmp(src, dst) == 0; + #endif + } + + // Compare lhs with [rhs_begin, rhs_end) + PUGI__FN bool strequalrange(const char_t* lhs, const char_t* rhs, size_t count) + { + for (size_t i = 0; i < count; ++i) + if (lhs[i] != rhs[i]) + return false; + + return lhs[count] == 0; + } + +#ifdef PUGIXML_WCHAR_MODE + // Convert string to wide string, assuming all symbols are ASCII + PUGI__FN void widen_ascii(wchar_t* dest, const char* source) + { + for (const char* i = source; *i; ++i) *dest++ = *i; + *dest = 0; + } +#endif +PUGI__NS_END + +#if !defined(PUGIXML_NO_STL) || !defined(PUGIXML_NO_XPATH) +// auto_ptr-like buffer holder for exception recovery +PUGI__NS_BEGIN + struct buffer_holder + { + void* data; + void (*deleter)(void*); + + buffer_holder(void* data_, void (*deleter_)(void*)): data(data_), deleter(deleter_) + { + } + + ~buffer_holder() + { + if (data) deleter(data); + } + + void* release() + { + void* result = data; + data = 0; + return result; + } + }; 
+PUGI__NS_END +#endif + +PUGI__NS_BEGIN + static const size_t xml_memory_page_size = + #ifdef PUGIXML_MEMORY_PAGE_SIZE + PUGIXML_MEMORY_PAGE_SIZE + #else + 32768 + #endif + ; + + static const uintptr_t xml_memory_page_alignment = 32; + static const uintptr_t xml_memory_page_pointer_mask = ~(xml_memory_page_alignment - 1); + static const uintptr_t xml_memory_page_name_allocated_mask = 16; + static const uintptr_t xml_memory_page_value_allocated_mask = 8; + static const uintptr_t xml_memory_page_type_mask = 7; + + struct xml_allocator; + + struct xml_memory_page + { + static xml_memory_page* construct(void* memory) + { + if (!memory) return 0; //$ redundant, left for performance + + xml_memory_page* result = static_cast(memory); + + result->allocator = 0; + result->memory = 0; + result->prev = 0; + result->next = 0; + result->busy_size = 0; + result->freed_size = 0; + + return result; + } + + xml_allocator* allocator; + + void* memory; + + xml_memory_page* prev; + xml_memory_page* next; + + size_t busy_size; + size_t freed_size; + + char data[1]; + }; + + struct xml_memory_string_header + { + uint16_t page_offset; // offset from page->data + uint16_t full_size; // 0 if string occupies whole page + }; + + struct xml_allocator + { + xml_allocator(xml_memory_page* root): _root(root), _busy_size(root->busy_size) + { + } + + xml_memory_page* allocate_page(size_t data_size) + { + size_t size = offsetof(xml_memory_page, data) + data_size; + + // allocate block with some alignment, leaving memory for worst-case padding + void* memory = xml_memory::allocate(size + xml_memory_page_alignment); + if (!memory) return 0; + + // align upwards to page boundary + void* page_memory = reinterpret_cast((reinterpret_cast(memory) + (xml_memory_page_alignment - 1)) & ~(xml_memory_page_alignment - 1)); + + // prepare page structure + xml_memory_page* page = xml_memory_page::construct(page_memory); + + page->memory = memory; + page->allocator = _root->allocator; + + return page; + } + + 
static void deallocate_page(xml_memory_page* page) + { + xml_memory::deallocate(page->memory); + } + + void* allocate_memory_oob(size_t size, xml_memory_page*& out_page); + + void* allocate_memory(size_t size, xml_memory_page*& out_page) + { + if (_busy_size + size > xml_memory_page_size) return allocate_memory_oob(size, out_page); + + void* buf = _root->data + _busy_size; + + _busy_size += size; + + out_page = _root; + + return buf; + } + + void deallocate_memory(void* ptr, size_t size, xml_memory_page* page) + { + if (page == _root) page->busy_size = _busy_size; + + assert(ptr >= page->data && ptr < page->data + page->busy_size); + (void)!ptr; + + page->freed_size += size; + assert(page->freed_size <= page->busy_size); + + if (page->freed_size == page->busy_size) + { + if (page->next == 0) + { + assert(_root == page); + + // top page freed, just reset sizes + page->busy_size = page->freed_size = 0; + _busy_size = 0; + } + else + { + assert(_root != page); + assert(page->prev); + + // remove from the list + page->prev->next = page->next; + page->next->prev = page->prev; + + // deallocate + deallocate_page(page); + } + } + } + + char_t* allocate_string(size_t length) + { + // allocate memory for string and header block + size_t size = sizeof(xml_memory_string_header) + length * sizeof(char_t); + + // round size up to pointer alignment boundary + size_t full_size = (size + (sizeof(void*) - 1)) & ~(sizeof(void*) - 1); + + xml_memory_page* page; + xml_memory_string_header* header = static_cast(allocate_memory(full_size, page)); + + if (!header) return 0; + + // setup header + ptrdiff_t page_offset = reinterpret_cast(header) - page->data; + + assert(page_offset >= 0 && page_offset < (1 << 16)); + header->page_offset = static_cast(page_offset); + + // full_size == 0 for large strings that occupy the whole page + assert(full_size < (1 << 16) || (page->busy_size == full_size && page_offset == 0)); + header->full_size = static_cast(full_size < (1 << 16) ? 
full_size : 0); + + // round-trip through void* to avoid 'cast increases required alignment of target type' warning + // header is guaranteed a pointer-sized alignment, which should be enough for char_t + return static_cast(static_cast(header + 1)); + } + + void deallocate_string(char_t* string) + { + // this function casts pointers through void* to avoid 'cast increases required alignment of target type' warnings + // we're guaranteed the proper (pointer-sized) alignment on the input string if it was allocated via allocate_string + + // get header + xml_memory_string_header* header = static_cast(static_cast(string)) - 1; + + // deallocate + size_t page_offset = offsetof(xml_memory_page, data) + header->page_offset; + xml_memory_page* page = reinterpret_cast(static_cast(reinterpret_cast(header) - page_offset)); + + // if full_size == 0 then this string occupies the whole page + size_t full_size = header->full_size == 0 ? page->busy_size : header->full_size; + + deallocate_memory(header, full_size, page); + } + + xml_memory_page* _root; + size_t _busy_size; + }; + + PUGI__FN_NO_INLINE void* xml_allocator::allocate_memory_oob(size_t size, xml_memory_page*& out_page) + { + const size_t large_allocation_threshold = xml_memory_page_size / 4; + + xml_memory_page* page = allocate_page(size <= large_allocation_threshold ? 
xml_memory_page_size : size); + out_page = page; + + if (!page) return 0; + + if (size <= large_allocation_threshold) + { + _root->busy_size = _busy_size; + + // insert page at the end of linked list + page->prev = _root; + _root->next = page; + _root = page; + + _busy_size = size; + } + else + { + // insert page before the end of linked list, so that it is deleted as soon as possible + // the last page is not deleted even if it's empty (see deallocate_memory) + assert(_root->prev); + + page->prev = _root->prev; + page->next = _root; + + _root->prev->next = page; + _root->prev = page; + } + + // allocate inside page + page->busy_size = size; + + return page->data; + } +PUGI__NS_END + +namespace pugi +{ + /// A 'name=value' XML attribute structure. + struct xml_attribute_struct + { + /// Default ctor + xml_attribute_struct(impl::xml_memory_page* page): header(reinterpret_cast(page)), name(0), value(0), prev_attribute_c(0), next_attribute(0) + { + } + + uintptr_t header; + + char_t* name; ///< Pointer to attribute name. + char_t* value; ///< Pointer to attribute value. + + xml_attribute_struct* prev_attribute_c; ///< Previous attribute (cyclic list) + xml_attribute_struct* next_attribute; ///< Next attribute + }; + + /// An XML document tree node. + struct xml_node_struct + { + /// Default ctor + /// \param type - node type + xml_node_struct(impl::xml_memory_page* page, xml_node_type type): header(reinterpret_cast(page) | (type - 1)), parent(0), name(0), value(0), first_child(0), prev_sibling_c(0), next_sibling(0), first_attribute(0) + { + } + + uintptr_t header; + + xml_node_struct* parent; ///< Pointer to parent + + char_t* name; ///< Pointer to element name. + char_t* value; ///< Pointer to any associated string data. 
+ + xml_node_struct* first_child; ///< First child + + xml_node_struct* prev_sibling_c; ///< Left brother (cyclic list) + xml_node_struct* next_sibling; ///< Right brother + + xml_attribute_struct* first_attribute; ///< First attribute + }; +} + +PUGI__NS_BEGIN + struct xml_document_struct: public xml_node_struct, public xml_allocator + { + xml_document_struct(xml_memory_page* page): xml_node_struct(page, node_document), xml_allocator(page), buffer(0) + { + } + + const char_t* buffer; + }; + + inline xml_allocator& get_allocator(const xml_node_struct* node) + { + assert(node); + + return *reinterpret_cast(node->header & xml_memory_page_pointer_mask)->allocator; + } +PUGI__NS_END + +// Low-level DOM operations +PUGI__NS_BEGIN + inline xml_attribute_struct* allocate_attribute(xml_allocator& alloc) + { + xml_memory_page* page; + void* memory = alloc.allocate_memory(sizeof(xml_attribute_struct), page); + + return new (memory) xml_attribute_struct(page); + } + + inline xml_node_struct* allocate_node(xml_allocator& alloc, xml_node_type type) + { + xml_memory_page* page; + void* memory = alloc.allocate_memory(sizeof(xml_node_struct), page); + + return new (memory) xml_node_struct(page, type); + } + + inline void destroy_attribute(xml_attribute_struct* a, xml_allocator& alloc) + { + uintptr_t header = a->header; + + if (header & impl::xml_memory_page_name_allocated_mask) alloc.deallocate_string(a->name); + if (header & impl::xml_memory_page_value_allocated_mask) alloc.deallocate_string(a->value); + + alloc.deallocate_memory(a, sizeof(xml_attribute_struct), reinterpret_cast(header & xml_memory_page_pointer_mask)); + } + + inline void destroy_node(xml_node_struct* n, xml_allocator& alloc) + { + uintptr_t header = n->header; + + if (header & impl::xml_memory_page_name_allocated_mask) alloc.deallocate_string(n->name); + if (header & impl::xml_memory_page_value_allocated_mask) alloc.deallocate_string(n->value); + + for (xml_attribute_struct* attr = n->first_attribute; attr; ) + 
{ + xml_attribute_struct* next = attr->next_attribute; + + destroy_attribute(attr, alloc); + + attr = next; + } + + for (xml_node_struct* child = n->first_child; child; ) + { + xml_node_struct* next = child->next_sibling; + + destroy_node(child, alloc); + + child = next; + } + + alloc.deallocate_memory(n, sizeof(xml_node_struct), reinterpret_cast(header & xml_memory_page_pointer_mask)); + } + + PUGI__FN_NO_INLINE xml_node_struct* append_node(xml_node_struct* node, xml_allocator& alloc, xml_node_type type = node_element) + { + xml_node_struct* child = allocate_node(alloc, type); + if (!child) return 0; + + child->parent = node; + + xml_node_struct* first_child = node->first_child; + + if (first_child) + { + xml_node_struct* last_child = first_child->prev_sibling_c; + + last_child->next_sibling = child; + child->prev_sibling_c = last_child; + first_child->prev_sibling_c = child; + } + else + { + node->first_child = child; + child->prev_sibling_c = child; + } + + return child; + } + + PUGI__FN_NO_INLINE xml_attribute_struct* append_attribute_ll(xml_node_struct* node, xml_allocator& alloc) + { + xml_attribute_struct* a = allocate_attribute(alloc); + if (!a) return 0; + + xml_attribute_struct* first_attribute = node->first_attribute; + + if (first_attribute) + { + xml_attribute_struct* last_attribute = first_attribute->prev_attribute_c; + + last_attribute->next_attribute = a; + a->prev_attribute_c = last_attribute; + first_attribute->prev_attribute_c = a; + } + else + { + node->first_attribute = a; + a->prev_attribute_c = a; + } + + return a; + } +PUGI__NS_END + +// Helper classes for code generation +PUGI__NS_BEGIN + struct opt_false + { + enum { value = 0 }; + }; + + struct opt_true + { + enum { value = 1 }; + }; +PUGI__NS_END + +// Unicode utilities +PUGI__NS_BEGIN + inline uint16_t endian_swap(uint16_t value) + { + return static_cast(((value & 0xff) << 8) | (value >> 8)); + } + + inline uint32_t endian_swap(uint32_t value) + { + return ((value & 0xff) << 24) | 
((value & 0xff00) << 8) | ((value & 0xff0000) >> 8) | (value >> 24); + } + + struct utf8_counter + { + typedef size_t value_type; + + static value_type low(value_type result, uint32_t ch) + { + // U+0000..U+007F + if (ch < 0x80) return result + 1; + // U+0080..U+07FF + else if (ch < 0x800) return result + 2; + // U+0800..U+FFFF + else return result + 3; + } + + static value_type high(value_type result, uint32_t) + { + // U+10000..U+10FFFF + return result + 4; + } + }; + + struct utf8_writer + { + typedef uint8_t* value_type; + + static value_type low(value_type result, uint32_t ch) + { + // U+0000..U+007F + if (ch < 0x80) + { + *result = static_cast(ch); + return result + 1; + } + // U+0080..U+07FF + else if (ch < 0x800) + { + result[0] = static_cast(0xC0 | (ch >> 6)); + result[1] = static_cast(0x80 | (ch & 0x3F)); + return result + 2; + } + // U+0800..U+FFFF + else + { + result[0] = static_cast(0xE0 | (ch >> 12)); + result[1] = static_cast(0x80 | ((ch >> 6) & 0x3F)); + result[2] = static_cast(0x80 | (ch & 0x3F)); + return result + 3; + } + } + + static value_type high(value_type result, uint32_t ch) + { + // U+10000..U+10FFFF + result[0] = static_cast(0xF0 | (ch >> 18)); + result[1] = static_cast(0x80 | ((ch >> 12) & 0x3F)); + result[2] = static_cast(0x80 | ((ch >> 6) & 0x3F)); + result[3] = static_cast(0x80 | (ch & 0x3F)); + return result + 4; + } + + static value_type any(value_type result, uint32_t ch) + { + return (ch < 0x10000) ? 
low(result, ch) : high(result, ch); + } + }; + + struct utf16_counter + { + typedef size_t value_type; + + static value_type low(value_type result, uint32_t) + { + return result + 1; + } + + static value_type high(value_type result, uint32_t) + { + return result + 2; + } + }; + + struct utf16_writer + { + typedef uint16_t* value_type; + + static value_type low(value_type result, uint32_t ch) + { + *result = static_cast(ch); + + return result + 1; + } + + static value_type high(value_type result, uint32_t ch) + { + uint32_t msh = static_cast(ch - 0x10000) >> 10; + uint32_t lsh = static_cast(ch - 0x10000) & 0x3ff; + + result[0] = static_cast(0xD800 + msh); + result[1] = static_cast(0xDC00 + lsh); + + return result + 2; + } + + static value_type any(value_type result, uint32_t ch) + { + return (ch < 0x10000) ? low(result, ch) : high(result, ch); + } + }; + + struct utf32_counter + { + typedef size_t value_type; + + static value_type low(value_type result, uint32_t) + { + return result + 1; + } + + static value_type high(value_type result, uint32_t) + { + return result + 1; + } + }; + + struct utf32_writer + { + typedef uint32_t* value_type; + + static value_type low(value_type result, uint32_t ch) + { + *result = ch; + + return result + 1; + } + + static value_type high(value_type result, uint32_t ch) + { + *result = ch; + + return result + 1; + } + + static value_type any(value_type result, uint32_t ch) + { + *result = ch; + + return result + 1; + } + }; + + struct latin1_writer + { + typedef uint8_t* value_type; + + static value_type low(value_type result, uint32_t ch) + { + *result = static_cast(ch > 255 ? '?' 
: ch); + + return result + 1; + } + + static value_type high(value_type result, uint32_t ch) + { + (void)ch; + + *result = '?'; + + return result + 1; + } + }; + + template struct wchar_selector; + + template <> struct wchar_selector<2> + { + typedef uint16_t type; + typedef utf16_counter counter; + typedef utf16_writer writer; + }; + + template <> struct wchar_selector<4> + { + typedef uint32_t type; + typedef utf32_counter counter; + typedef utf32_writer writer; + }; + + typedef wchar_selector::counter wchar_counter; + typedef wchar_selector::writer wchar_writer; + + template struct utf_decoder + { + static inline typename Traits::value_type decode_utf8_block(const uint8_t* data, size_t size, typename Traits::value_type result) + { + const uint8_t utf8_byte_mask = 0x3f; + + while (size) + { + uint8_t lead = *data; + + // 0xxxxxxx -> U+0000..U+007F + if (lead < 0x80) + { + result = Traits::low(result, lead); + data += 1; + size -= 1; + + // process aligned single-byte (ascii) blocks + if ((reinterpret_cast(data) & 3) == 0) + { + // round-trip through void* to silence 'cast increases required alignment of target type' warnings + while (size >= 4 && (*static_cast(static_cast(data)) & 0x80808080) == 0) + { + result = Traits::low(result, data[0]); + result = Traits::low(result, data[1]); + result = Traits::low(result, data[2]); + result = Traits::low(result, data[3]); + data += 4; + size -= 4; + } + } + } + // 110xxxxx -> U+0080..U+07FF + else if (static_cast(lead - 0xC0) < 0x20 && size >= 2 && (data[1] & 0xc0) == 0x80) + { + result = Traits::low(result, ((lead & ~0xC0) << 6) | (data[1] & utf8_byte_mask)); + data += 2; + size -= 2; + } + // 1110xxxx -> U+0800-U+FFFF + else if (static_cast(lead - 0xE0) < 0x10 && size >= 3 && (data[1] & 0xc0) == 0x80 && (data[2] & 0xc0) == 0x80) + { + result = Traits::low(result, ((lead & ~0xE0) << 12) | ((data[1] & utf8_byte_mask) << 6) | (data[2] & utf8_byte_mask)); + data += 3; + size -= 3; + } + // 11110xxx -> U+10000..U+10FFFF + 
else if (static_cast(lead - 0xF0) < 0x08 && size >= 4 && (data[1] & 0xc0) == 0x80 && (data[2] & 0xc0) == 0x80 && (data[3] & 0xc0) == 0x80) + { + result = Traits::high(result, ((lead & ~0xF0) << 18) | ((data[1] & utf8_byte_mask) << 12) | ((data[2] & utf8_byte_mask) << 6) | (data[3] & utf8_byte_mask)); + data += 4; + size -= 4; + } + // 10xxxxxx or 11111xxx -> invalid + else + { + data += 1; + size -= 1; + } + } + + return result; + } + + static inline typename Traits::value_type decode_utf16_block(const uint16_t* data, size_t size, typename Traits::value_type result) + { + const uint16_t* end = data + size; + + while (data < end) + { + uint16_t lead = opt_swap::value ? endian_swap(*data) : *data; + + // U+0000..U+D7FF + if (lead < 0xD800) + { + result = Traits::low(result, lead); + data += 1; + } + // U+E000..U+FFFF + else if (static_cast(lead - 0xE000) < 0x2000) + { + result = Traits::low(result, lead); + data += 1; + } + // surrogate pair lead + else if (static_cast(lead - 0xD800) < 0x400 && data + 1 < end) + { + uint16_t next = opt_swap::value ? endian_swap(data[1]) : data[1]; + + if (static_cast(next - 0xDC00) < 0x400) + { + result = Traits::high(result, 0x10000 + ((lead & 0x3ff) << 10) + (next & 0x3ff)); + data += 2; + } + else + { + data += 1; + } + } + else + { + data += 1; + } + } + + return result; + } + + static inline typename Traits::value_type decode_utf32_block(const uint32_t* data, size_t size, typename Traits::value_type result) + { + const uint32_t* end = data + size; + + while (data < end) + { + uint32_t lead = opt_swap::value ? 
endian_swap(*data) : *data; + + // U+0000..U+FFFF + if (lead < 0x10000) + { + result = Traits::low(result, lead); + data += 1; + } + // U+10000..U+10FFFF + else + { + result = Traits::high(result, lead); + data += 1; + } + } + + return result; + } + + static inline typename Traits::value_type decode_latin1_block(const uint8_t* data, size_t size, typename Traits::value_type result) + { + for (size_t i = 0; i < size; ++i) + { + result = Traits::low(result, data[i]); + } + + return result; + } + + static inline typename Traits::value_type decode_wchar_block_impl(const uint16_t* data, size_t size, typename Traits::value_type result) + { + return decode_utf16_block(data, size, result); + } + + static inline typename Traits::value_type decode_wchar_block_impl(const uint32_t* data, size_t size, typename Traits::value_type result) + { + return decode_utf32_block(data, size, result); + } + + static inline typename Traits::value_type decode_wchar_block(const wchar_t* data, size_t size, typename Traits::value_type result) + { + return decode_wchar_block_impl(reinterpret_cast::type*>(data), size, result); + } + }; + + template PUGI__FN void convert_utf_endian_swap(T* result, const T* data, size_t length) + { + for (size_t i = 0; i < length; ++i) result[i] = endian_swap(data[i]); + } + +#ifdef PUGIXML_WCHAR_MODE + PUGI__FN void convert_wchar_endian_swap(wchar_t* result, const wchar_t* data, size_t length) + { + for (size_t i = 0; i < length; ++i) result[i] = static_cast(endian_swap(static_cast::type>(data[i]))); + } +#endif +PUGI__NS_END + +PUGI__NS_BEGIN + enum chartype_t + { + ct_parse_pcdata = 1, // \0, &, \r, < + ct_parse_attr = 2, // \0, &, \r, ', " + ct_parse_attr_ws = 4, // \0, &, \r, ', ", \n, tab + ct_space = 8, // \r, \n, space, tab + ct_parse_cdata = 16, // \0, ], >, \r + ct_parse_comment = 32, // \0, -, >, \r + ct_symbol = 64, // Any symbol > 127, a-z, A-Z, 0-9, _, :, -, . 
+ ct_start_symbol = 128 // Any symbol > 127, a-z, A-Z, _, : + }; + + static const unsigned char chartype_table[256] = + { + 55, 0, 0, 0, 0, 0, 0, 0, 0, 12, 12, 0, 0, 63, 0, 0, // 0-15 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 16-31 + 8, 0, 6, 0, 0, 0, 7, 6, 0, 0, 0, 0, 0, 96, 64, 0, // 32-47 + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 192, 0, 1, 0, 48, 0, // 48-63 + 0, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, // 64-79 + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 0, 0, 16, 0, 192, // 80-95 + 0, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, // 96-111 + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 0, 0, 0, 0, 0, // 112-127 + + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, // 128+ + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192 + }; + + enum chartypex_t + { + ctx_special_pcdata = 1, // Any symbol >= 0 and < 32 (except \t, \r, \n), &, <, > + ctx_special_attr = 2, // Any symbol >= 0 and < 32 (except \t), &, <, >, " + ctx_start_symbol = 4, // Any symbol > 127, a-z, A-Z, _ + ctx_digit = 8, // 0-9 + ctx_symbol = 16 // Any symbol > 127, a-z, A-Z, 0-9, _, -, . 
+ }; + + static const unsigned char chartypex_table[256] = + { + 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 2, 3, 3, 2, 3, 3, // 0-15 + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 16-31 + 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 16, 16, 0, // 32-47 + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 0, 0, 3, 0, 3, 0, // 48-63 + + 0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, // 64-79 + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 0, 0, 0, 0, 20, // 80-95 + 0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, // 96-111 + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 0, 0, 0, 0, 0, // 112-127 + + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, // 128+ + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20 + }; + +#ifdef PUGIXML_WCHAR_MODE + #define PUGI__IS_CHARTYPE_IMPL(c, ct, table) ((static_cast(c) < 128 ? table[static_cast(c)] : table[128]) & (ct)) +#else + #define PUGI__IS_CHARTYPE_IMPL(c, ct, table) (table[static_cast(c)] & (ct)) +#endif + + #define PUGI__IS_CHARTYPE(c, ct) PUGI__IS_CHARTYPE_IMPL(c, ct, chartype_table) + #define PUGI__IS_CHARTYPEX(c, ct) PUGI__IS_CHARTYPE_IMPL(c, ct, chartypex_table) + + PUGI__FN bool is_little_endian() + { + unsigned int ui = 1; + + return *reinterpret_cast(&ui) == 1; + } + + PUGI__FN xml_encoding get_wchar_encoding() + { + PUGI__STATIC_ASSERT(sizeof(wchar_t) == 2 || sizeof(wchar_t) == 4); + + if (sizeof(wchar_t) == 2) + return is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + else + return is_little_endian() ? 
encoding_utf32_le : encoding_utf32_be; + } + + PUGI__FN xml_encoding guess_buffer_encoding(uint8_t d0, uint8_t d1, uint8_t d2, uint8_t d3) + { + // look for BOM in first few bytes + if (d0 == 0 && d1 == 0 && d2 == 0xfe && d3 == 0xff) return encoding_utf32_be; + if (d0 == 0xff && d1 == 0xfe && d2 == 0 && d3 == 0) return encoding_utf32_le; + if (d0 == 0xfe && d1 == 0xff) return encoding_utf16_be; + if (d0 == 0xff && d1 == 0xfe) return encoding_utf16_le; + if (d0 == 0xef && d1 == 0xbb && d2 == 0xbf) return encoding_utf8; + + // look for <, (contents); + + PUGI__DMC_VOLATILE uint8_t d0 = data[0], d1 = data[1], d2 = data[2], d3 = data[3]; + + return guess_buffer_encoding(d0, d1, d2, d3); + } + + PUGI__FN bool get_mutable_buffer(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, bool is_mutable) + { + if (is_mutable) + { + out_buffer = static_cast(const_cast(contents)); + } + else + { + void* buffer = xml_memory::allocate(size > 0 ? size : 1); + if (!buffer) return false; + + memcpy(buffer, contents, size); + + out_buffer = static_cast(buffer); + } + + out_length = size / sizeof(char_t); + + return true; + } + +#ifdef PUGIXML_WCHAR_MODE + PUGI__FN bool need_endian_swap_utf(xml_encoding le, xml_encoding re) + { + return (le == encoding_utf16_be && re == encoding_utf16_le) || (le == encoding_utf16_le && re == encoding_utf16_be) || + (le == encoding_utf32_be && re == encoding_utf32_le) || (le == encoding_utf32_le && re == encoding_utf32_be); + } + + PUGI__FN bool convert_buffer_endian_swap(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, bool is_mutable) + { + const char_t* data = static_cast(contents); + + if (is_mutable) + { + out_buffer = const_cast(data); + } + else + { + out_buffer = static_cast(xml_memory::allocate(size > 0 ? 
size : 1)); + if (!out_buffer) return false; + } + + out_length = size / sizeof(char_t); + + convert_wchar_endian_swap(out_buffer, data, out_length); + + return true; + } + + PUGI__FN bool convert_buffer_utf8(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size) + { + const uint8_t* data = static_cast(contents); + + // first pass: get length in wchar_t units + out_length = utf_decoder::decode_utf8_block(data, size, 0); + + // allocate buffer of suitable length + out_buffer = static_cast(xml_memory::allocate((out_length > 0 ? out_length : 1) * sizeof(char_t))); + if (!out_buffer) return false; + + // second pass: convert utf8 input to wchar_t + wchar_writer::value_type out_begin = reinterpret_cast(out_buffer); + wchar_writer::value_type out_end = utf_decoder::decode_utf8_block(data, size, out_begin); + + assert(out_end == out_begin + out_length); + (void)!out_end; + + return true; + } + + template PUGI__FN bool convert_buffer_utf16(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, opt_swap) + { + const uint16_t* data = static_cast(contents); + size_t length = size / sizeof(uint16_t); + + // first pass: get length in wchar_t units + out_length = utf_decoder::decode_utf16_block(data, length, 0); + + // allocate buffer of suitable length + out_buffer = static_cast(xml_memory::allocate((out_length > 0 ? 
out_length : 1) * sizeof(char_t))); + if (!out_buffer) return false; + + // second pass: convert utf16 input to wchar_t + wchar_writer::value_type out_begin = reinterpret_cast(out_buffer); + wchar_writer::value_type out_end = utf_decoder::decode_utf16_block(data, length, out_begin); + + assert(out_end == out_begin + out_length); + (void)!out_end; + + return true; + } + + template PUGI__FN bool convert_buffer_utf32(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, opt_swap) + { + const uint32_t* data = static_cast(contents); + size_t length = size / sizeof(uint32_t); + + // first pass: get length in wchar_t units + out_length = utf_decoder::decode_utf32_block(data, length, 0); + + // allocate buffer of suitable length + out_buffer = static_cast(xml_memory::allocate((out_length > 0 ? out_length : 1) * sizeof(char_t))); + if (!out_buffer) return false; + + // second pass: convert utf32 input to wchar_t + wchar_writer::value_type out_begin = reinterpret_cast(out_buffer); + wchar_writer::value_type out_end = utf_decoder::decode_utf32_block(data, length, out_begin); + + assert(out_end == out_begin + out_length); + (void)!out_end; + + return true; + } + + PUGI__FN bool convert_buffer_latin1(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size) + { + const uint8_t* data = static_cast(contents); + + // get length in wchar_t units + out_length = size; + + // allocate buffer of suitable length + out_buffer = static_cast(xml_memory::allocate((out_length > 0 ? 
out_length : 1) * sizeof(char_t))); + if (!out_buffer) return false; + + // convert latin1 input to wchar_t + wchar_writer::value_type out_begin = reinterpret_cast(out_buffer); + wchar_writer::value_type out_end = utf_decoder::decode_latin1_block(data, size, out_begin); + + assert(out_end == out_begin + out_length); + (void)!out_end; + + return true; + } + + PUGI__FN bool convert_buffer(char_t*& out_buffer, size_t& out_length, xml_encoding encoding, const void* contents, size_t size, bool is_mutable) + { + // get native encoding + xml_encoding wchar_encoding = get_wchar_encoding(); + + // fast path: no conversion required + if (encoding == wchar_encoding) return get_mutable_buffer(out_buffer, out_length, contents, size, is_mutable); + + // only endian-swapping is required + if (need_endian_swap_utf(encoding, wchar_encoding)) return convert_buffer_endian_swap(out_buffer, out_length, contents, size, is_mutable); + + // source encoding is utf8 + if (encoding == encoding_utf8) return convert_buffer_utf8(out_buffer, out_length, contents, size); + + // source encoding is utf16 + if (encoding == encoding_utf16_be || encoding == encoding_utf16_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + + return (native_encoding == encoding) ? + convert_buffer_utf16(out_buffer, out_length, contents, size, opt_false()) : + convert_buffer_utf16(out_buffer, out_length, contents, size, opt_true()); + } + + // source encoding is utf32 + if (encoding == encoding_utf32_be || encoding == encoding_utf32_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + return (native_encoding == encoding) ? 
+ convert_buffer_utf32(out_buffer, out_length, contents, size, opt_false()) : + convert_buffer_utf32(out_buffer, out_length, contents, size, opt_true()); + } + + // source encoding is latin1 + if (encoding == encoding_latin1) return convert_buffer_latin1(out_buffer, out_length, contents, size); + + assert(!"Invalid encoding"); + return false; + } +#else + template PUGI__FN bool convert_buffer_utf16(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, opt_swap) + { + const uint16_t* data = static_cast(contents); + size_t length = size / sizeof(uint16_t); + + // first pass: get length in utf8 units + out_length = utf_decoder::decode_utf16_block(data, length, 0); + + // allocate buffer of suitable length + out_buffer = static_cast(xml_memory::allocate((out_length > 0 ? out_length : 1) * sizeof(char_t))); + if (!out_buffer) return false; + + // second pass: convert utf16 input to utf8 + uint8_t* out_begin = reinterpret_cast(out_buffer); + uint8_t* out_end = utf_decoder::decode_utf16_block(data, length, out_begin); + + assert(out_end == out_begin + out_length); + (void)!out_end; + + return true; + } + + template PUGI__FN bool convert_buffer_utf32(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, opt_swap) + { + const uint32_t* data = static_cast(contents); + size_t length = size / sizeof(uint32_t); + + // first pass: get length in utf8 units + out_length = utf_decoder::decode_utf32_block(data, length, 0); + + // allocate buffer of suitable length + out_buffer = static_cast(xml_memory::allocate((out_length > 0 ? 
out_length : 1) * sizeof(char_t))); + if (!out_buffer) return false; + + // second pass: convert utf32 input to utf8 + uint8_t* out_begin = reinterpret_cast(out_buffer); + uint8_t* out_end = utf_decoder::decode_utf32_block(data, length, out_begin); + + assert(out_end == out_begin + out_length); + (void)!out_end; + + return true; + } + + PUGI__FN size_t get_latin1_7bit_prefix_length(const uint8_t* data, size_t size) + { + for (size_t i = 0; i < size; ++i) + if (data[i] > 127) + return i; + + return size; + } + + PUGI__FN bool convert_buffer_latin1(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, bool is_mutable) + { + const uint8_t* data = static_cast(contents); + + // get size of prefix that does not need utf8 conversion + size_t prefix_length = get_latin1_7bit_prefix_length(data, size); + assert(prefix_length <= size); + + const uint8_t* postfix = data + prefix_length; + size_t postfix_length = size - prefix_length; + + // if no conversion is needed, just return the original buffer + if (postfix_length == 0) return get_mutable_buffer(out_buffer, out_length, contents, size, is_mutable); + + // first pass: get length in utf8 units + out_length = prefix_length + utf_decoder::decode_latin1_block(postfix, postfix_length, 0); + + // allocate buffer of suitable length + out_buffer = static_cast(xml_memory::allocate((out_length > 0 ? 
out_length : 1) * sizeof(char_t))); + if (!out_buffer) return false; + + // second pass: convert latin1 input to utf8 + memcpy(out_buffer, data, prefix_length); + + uint8_t* out_begin = reinterpret_cast(out_buffer); + uint8_t* out_end = utf_decoder::decode_latin1_block(postfix, postfix_length, out_begin + prefix_length); + + assert(out_end == out_begin + out_length); + (void)!out_end; + + return true; + } + + PUGI__FN bool convert_buffer(char_t*& out_buffer, size_t& out_length, xml_encoding encoding, const void* contents, size_t size, bool is_mutable) + { + // fast path: no conversion required + if (encoding == encoding_utf8) return get_mutable_buffer(out_buffer, out_length, contents, size, is_mutable); + + // source encoding is utf16 + if (encoding == encoding_utf16_be || encoding == encoding_utf16_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + + return (native_encoding == encoding) ? + convert_buffer_utf16(out_buffer, out_length, contents, size, opt_false()) : + convert_buffer_utf16(out_buffer, out_length, contents, size, opt_true()); + } + + // source encoding is utf32 + if (encoding == encoding_utf32_be || encoding == encoding_utf32_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + return (native_encoding == encoding) ? 
+ convert_buffer_utf32(out_buffer, out_length, contents, size, opt_false()) : + convert_buffer_utf32(out_buffer, out_length, contents, size, opt_true()); + } + + // source encoding is latin1 + if (encoding == encoding_latin1) return convert_buffer_latin1(out_buffer, out_length, contents, size, is_mutable); + + assert(!"Invalid encoding"); + return false; + } +#endif + + PUGI__FN size_t as_utf8_begin(const wchar_t* str, size_t length) + { + // get length in utf8 characters + return utf_decoder::decode_wchar_block(str, length, 0); + } + + PUGI__FN void as_utf8_end(char* buffer, size_t size, const wchar_t* str, size_t length) + { + // convert to utf8 + uint8_t* begin = reinterpret_cast(buffer); + uint8_t* end = utf_decoder::decode_wchar_block(str, length, begin); + + assert(begin + size == end); + (void)!end; + + // zero-terminate + buffer[size] = 0; + } + +#ifndef PUGIXML_NO_STL + PUGI__FN std::string as_utf8_impl(const wchar_t* str, size_t length) + { + // first pass: get length in utf8 characters + size_t size = as_utf8_begin(str, length); + + // allocate resulting string + std::string result; + result.resize(size); + + // second pass: convert to utf8 + if (size > 0) as_utf8_end(&result[0], size, str, length); + + return result; + } + + PUGI__FN std::basic_string as_wide_impl(const char* str, size_t size) + { + const uint8_t* data = reinterpret_cast(str); + + // first pass: get length in wchar_t units + size_t length = utf_decoder::decode_utf8_block(data, size, 0); + + // allocate resulting string + std::basic_string result; + result.resize(length); + + // second pass: convert to wchar_t + if (length > 0) + { + wchar_writer::value_type begin = reinterpret_cast(&result[0]); + wchar_writer::value_type end = utf_decoder::decode_utf8_block(data, size, begin); + + assert(begin + length == end); + (void)!end; + } + + return result; + } +#endif + + inline bool strcpy_insitu_allow(size_t length, uintptr_t allocated, char_t* target) + { + assert(target); + size_t 
target_length = strlength(target); + + // always reuse document buffer memory if possible + if (!allocated) return target_length >= length; + + // reuse heap memory if waste is not too great + const size_t reuse_threshold = 32; + + return target_length >= length && (target_length < reuse_threshold || target_length - length < target_length / 2); + } + + PUGI__FN bool strcpy_insitu(char_t*& dest, uintptr_t& header, uintptr_t header_mask, const char_t* source) + { + size_t source_length = strlength(source); + + if (source_length == 0) + { + // empty string and null pointer are equivalent, so just deallocate old memory + xml_allocator* alloc = reinterpret_cast(header & xml_memory_page_pointer_mask)->allocator; + + if (header & header_mask) alloc->deallocate_string(dest); + + // mark the string as not allocated + dest = 0; + header &= ~header_mask; + + return true; + } + else if (dest && strcpy_insitu_allow(source_length, header & header_mask, dest)) + { + // we can reuse old buffer, so just copy the new data (including zero terminator) + memcpy(dest, source, (source_length + 1) * sizeof(char_t)); + + return true; + } + else + { + xml_allocator* alloc = reinterpret_cast(header & xml_memory_page_pointer_mask)->allocator; + + // allocate new buffer + char_t* buf = alloc->allocate_string(source_length + 1); + if (!buf) return false; + + // copy the string (including zero terminator) + memcpy(buf, source, (source_length + 1) * sizeof(char_t)); + + // deallocate old buffer (*after* the above to protect against overlapping memory and/or allocation failures) + if (header & header_mask) alloc->deallocate_string(dest); + + // the string is now allocated, so set the flag + dest = buf; + header |= header_mask; + + return true; + } + } + + struct gap + { + char_t* end; + size_t size; + + gap(): end(0), size(0) + { + } + + // Push new gap, move s count bytes further (skipping the gap). + // Collapse previous gap. 
+ void push(char_t*& s, size_t count) + { + if (end) // there was a gap already; collapse it + { + // Move [old_gap_end, new_gap_start) to [old_gap_start, ...) + assert(s >= end); + memmove(end - size, end, reinterpret_cast(s) - reinterpret_cast(end)); + } + + s += count; // end of current gap + + // "merge" two gaps + end = s; + size += count; + } + + // Collapse all gaps, return past-the-end pointer + char_t* flush(char_t* s) + { + if (end) + { + // Move [old_gap_end, current_pos) to [old_gap_start, ...) + assert(s >= end); + memmove(end - size, end, reinterpret_cast(s) - reinterpret_cast(end)); + + return s - size; + } + else return s; + } + }; + + PUGI__FN char_t* strconv_escape(char_t* s, gap& g) + { + char_t* stre = s + 1; + + switch (*stre) + { + case '#': // &#... + { + unsigned int ucsc = 0; + + if (stre[1] == 'x') // &#x... (hex code) + { + stre += 2; + + char_t ch = *stre; + + if (ch == ';') return stre; + + for (;;) + { + if (static_cast(ch - '0') <= 9) + ucsc = 16 * ucsc + (ch - '0'); + else if (static_cast((ch | ' ') - 'a') <= 5) + ucsc = 16 * ucsc + ((ch | ' ') - 'a' + 10); + else if (ch == ';') + break; + else // cancel + return stre; + + ch = *++stre; + } + + ++stre; + } + else // &#... 
(dec code) + { + char_t ch = *++stre; + + if (ch == ';') return stre; + + for (;;) + { + if (static_cast(ch - '0') <= 9) + ucsc = 10 * ucsc + (ch - '0'); + else if (ch == ';') + break; + else // cancel + return stre; + + ch = *++stre; + } + + ++stre; + } + + #ifdef PUGIXML_WCHAR_MODE + s = reinterpret_cast(wchar_writer::any(reinterpret_cast(s), ucsc)); + #else + s = reinterpret_cast(utf8_writer::any(reinterpret_cast(s), ucsc)); + #endif + + g.push(s, stre - s); + return stre; + } + + case 'a': // &a + { + ++stre; + + if (*stre == 'm') // &am + { + if (*++stre == 'p' && *++stre == ';') // & + { + *s++ = '&'; + ++stre; + + g.push(s, stre - s); + return stre; + } + } + else if (*stre == 'p') // &ap + { + if (*++stre == 'o' && *++stre == 's' && *++stre == ';') // ' + { + *s++ = '\''; + ++stre; + + g.push(s, stre - s); + return stre; + } + } + break; + } + + case 'g': // &g + { + if (*++stre == 't' && *++stre == ';') // > + { + *s++ = '>'; + ++stre; + + g.push(s, stre - s); + return stre; + } + break; + } + + case 'l': // &l + { + if (*++stre == 't' && *++stre == ';') // < + { + *s++ = '<'; + ++stre; + + g.push(s, stre - s); + return stre; + } + break; + } + + case 'q': // &q + { + if (*++stre == 'u' && *++stre == 'o' && *++stre == 't' && *++stre == ';') // " + { + *s++ = '"'; + ++stre; + + g.push(s, stre - s); + return stre; + } + break; + } + + default: + break; + } + + return stre; + } + + // Utility macro for last character handling + #define ENDSWITH(c, e) ((c) == (e) || ((c) == 0 && endch == (e))) + + PUGI__FN char_t* strconv_comment(char_t* s, char_t endch) + { + gap g; + + while (true) + { + while (!PUGI__IS_CHARTYPE(*s, ct_parse_comment)) ++s; + + if (*s == '\r') // Either a single 0x0d or 0x0d 0x0a pair + { + *s++ = '\n'; // replace first one with 0x0a + + if (*s == '\n') g.push(s, 1); + } + else if (s[0] == '-' && s[1] == '-' && ENDSWITH(s[2], '>')) // comment ends here + { + *g.flush(s) = 0; + + return s + (s[2] == '>' ? 
3 : 2); + } + else if (*s == 0) + { + return 0; + } + else ++s; + } + } + + PUGI__FN char_t* strconv_cdata(char_t* s, char_t endch) + { + gap g; + + while (true) + { + while (!PUGI__IS_CHARTYPE(*s, ct_parse_cdata)) ++s; + + if (*s == '\r') // Either a single 0x0d or 0x0d 0x0a pair + { + *s++ = '\n'; // replace first one with 0x0a + + if (*s == '\n') g.push(s, 1); + } + else if (s[0] == ']' && s[1] == ']' && ENDSWITH(s[2], '>')) // CDATA ends here + { + *g.flush(s) = 0; + + return s + 1; + } + else if (*s == 0) + { + return 0; + } + else ++s; + } + } + + typedef char_t* (*strconv_pcdata_t)(char_t*); + + template struct strconv_pcdata_impl + { + static char_t* parse(char_t* s) + { + gap g; + + while (true) + { + while (!PUGI__IS_CHARTYPE(*s, ct_parse_pcdata)) ++s; + + if (*s == '<') // PCDATA ends here + { + *g.flush(s) = 0; + + return s + 1; + } + else if (opt_eol::value && *s == '\r') // Either a single 0x0d or 0x0d 0x0a pair + { + *s++ = '\n'; // replace first one with 0x0a + + if (*s == '\n') g.push(s, 1); + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (*s == 0) + { + return s; + } + else ++s; + } + } + }; + + PUGI__FN strconv_pcdata_t get_strconv_pcdata(unsigned int optmask) + { + PUGI__STATIC_ASSERT(parse_escapes == 0x10 && parse_eol == 0x20); + + switch ((optmask >> 4) & 3) // get bitmask for flags (eol escapes) + { + case 0: return strconv_pcdata_impl::parse; + case 1: return strconv_pcdata_impl::parse; + case 2: return strconv_pcdata_impl::parse; + case 3: return strconv_pcdata_impl::parse; + default: return 0; // should not get here + } + } + + typedef char_t* (*strconv_attribute_t)(char_t*, char_t); + + template struct strconv_attribute_impl + { + static char_t* parse_wnorm(char_t* s, char_t end_quote) + { + gap g; + + // trim leading whitespaces + if (PUGI__IS_CHARTYPE(*s, ct_space)) + { + char_t* str = s; + + do ++str; + while (PUGI__IS_CHARTYPE(*str, ct_space)); + + g.push(s, str - s); + } + + while (true) + 
{ + while (!PUGI__IS_CHARTYPE(*s, ct_parse_attr_ws | ct_space)) ++s; + + if (*s == end_quote) + { + char_t* str = g.flush(s); + + do *str-- = 0; + while (PUGI__IS_CHARTYPE(*str, ct_space)); + + return s + 1; + } + else if (PUGI__IS_CHARTYPE(*s, ct_space)) + { + *s++ = ' '; + + if (PUGI__IS_CHARTYPE(*s, ct_space)) + { + char_t* str = s + 1; + while (PUGI__IS_CHARTYPE(*str, ct_space)) ++str; + + g.push(s, str - s); + } + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (!*s) + { + return 0; + } + else ++s; + } + } + + static char_t* parse_wconv(char_t* s, char_t end_quote) + { + gap g; + + while (true) + { + while (!PUGI__IS_CHARTYPE(*s, ct_parse_attr_ws)) ++s; + + if (*s == end_quote) + { + *g.flush(s) = 0; + + return s + 1; + } + else if (PUGI__IS_CHARTYPE(*s, ct_space)) + { + if (*s == '\r') + { + *s++ = ' '; + + if (*s == '\n') g.push(s, 1); + } + else *s++ = ' '; + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (!*s) + { + return 0; + } + else ++s; + } + } + + static char_t* parse_eol(char_t* s, char_t end_quote) + { + gap g; + + while (true) + { + while (!PUGI__IS_CHARTYPE(*s, ct_parse_attr)) ++s; + + if (*s == end_quote) + { + *g.flush(s) = 0; + + return s + 1; + } + else if (*s == '\r') + { + *s++ = '\n'; + + if (*s == '\n') g.push(s, 1); + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (!*s) + { + return 0; + } + else ++s; + } + } + + static char_t* parse_simple(char_t* s, char_t end_quote) + { + gap g; + + while (true) + { + while (!PUGI__IS_CHARTYPE(*s, ct_parse_attr)) ++s; + + if (*s == end_quote) + { + *g.flush(s) = 0; + + return s + 1; + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (!*s) + { + return 0; + } + else ++s; + } + } + }; + + PUGI__FN strconv_attribute_t get_strconv_attribute(unsigned int optmask) + { + PUGI__STATIC_ASSERT(parse_escapes == 0x10 && parse_eol == 0x20 && 
parse_wconv_attribute == 0x40 && parse_wnorm_attribute == 0x80); + + switch ((optmask >> 4) & 15) // get bitmask for flags (wconv wnorm eol escapes) + { + case 0: return strconv_attribute_impl::parse_simple; + case 1: return strconv_attribute_impl::parse_simple; + case 2: return strconv_attribute_impl::parse_eol; + case 3: return strconv_attribute_impl::parse_eol; + case 4: return strconv_attribute_impl::parse_wconv; + case 5: return strconv_attribute_impl::parse_wconv; + case 6: return strconv_attribute_impl::parse_wconv; + case 7: return strconv_attribute_impl::parse_wconv; + case 8: return strconv_attribute_impl::parse_wnorm; + case 9: return strconv_attribute_impl::parse_wnorm; + case 10: return strconv_attribute_impl::parse_wnorm; + case 11: return strconv_attribute_impl::parse_wnorm; + case 12: return strconv_attribute_impl::parse_wnorm; + case 13: return strconv_attribute_impl::parse_wnorm; + case 14: return strconv_attribute_impl::parse_wnorm; + case 15: return strconv_attribute_impl::parse_wnorm; + default: return 0; // should not get here + } + } + + inline xml_parse_result make_parse_result(xml_parse_status status, ptrdiff_t offset = 0) + { + xml_parse_result result; + result.status = status; + result.offset = offset; + + return result; + } + + struct xml_parser + { + xml_allocator alloc; + char_t* error_offset; + xml_parse_status error_status; + + // Parser utilities. 
+ #define PUGI__SKIPWS() { while (PUGI__IS_CHARTYPE(*s, ct_space)) ++s; } + #define PUGI__OPTSET(OPT) ( optmsk & (OPT) ) + #define PUGI__PUSHNODE(TYPE) { cursor = append_node(cursor, alloc, TYPE); if (!cursor) PUGI__THROW_ERROR(status_out_of_memory, s); } + #define PUGI__POPNODE() { cursor = cursor->parent; } + #define PUGI__SCANFOR(X) { while (*s != 0 && !(X)) ++s; } + #define PUGI__SCANWHILE(X) { while ((X)) ++s; } + #define PUGI__ENDSEG() { ch = *s; *s = 0; ++s; } + #define PUGI__THROW_ERROR(err, m) return error_offset = m, error_status = err, static_cast(0) + #define PUGI__CHECK_ERROR(err, m) { if (*s == 0) PUGI__THROW_ERROR(err, m); } + + xml_parser(const xml_allocator& alloc_): alloc(alloc_), error_offset(0), error_status(status_ok) + { + } + + // DOCTYPE consists of nested sections of the following possible types: + // , , "...", '...' + // + // + // First group can not contain nested groups + // Second group can contain nested groups of the same type + // Third group can contain all other groups + char_t* parse_doctype_primitive(char_t* s) + { + if (*s == '"' || *s == '\'') + { + // quoted string + char_t ch = *s++; + PUGI__SCANFOR(*s == ch); + if (!*s) PUGI__THROW_ERROR(status_bad_doctype, s); + + s++; + } + else if (s[0] == '<' && s[1] == '?') + { + // + s += 2; + PUGI__SCANFOR(s[0] == '?' && s[1] == '>'); // no need for ENDSWITH because ?> can't terminate proper doctype + if (!*s) PUGI__THROW_ERROR(status_bad_doctype, s); + + s += 2; + } + else if (s[0] == '<' && s[1] == '!' && s[2] == '-' && s[3] == '-') + { + s += 4; + PUGI__SCANFOR(s[0] == '-' && s[1] == '-' && s[2] == '>'); // no need for ENDSWITH because --> can't terminate proper doctype + if (!*s) PUGI__THROW_ERROR(status_bad_doctype, s); + + s += 4; + } + else PUGI__THROW_ERROR(status_bad_doctype, s); + + return s; + } + + char_t* parse_doctype_ignore(char_t* s) + { + assert(s[0] == '<' && s[1] == '!' && s[2] == '['); + s++; + + while (*s) + { + if (s[0] == '<' && s[1] == '!' 
&& s[2] == '[') + { + // nested ignore section + s = parse_doctype_ignore(s); + if (!s) return s; + } + else if (s[0] == ']' && s[1] == ']' && s[2] == '>') + { + // ignore section end + s += 3; + + return s; + } + else s++; + } + + PUGI__THROW_ERROR(status_bad_doctype, s); + } + + char_t* parse_doctype_group(char_t* s, char_t endch, bool toplevel) + { + assert(s[0] == '<' && s[1] == '!'); + s++; + + while (*s) + { + if (s[0] == '<' && s[1] == '!' && s[2] != '-') + { + if (s[2] == '[') + { + // ignore + s = parse_doctype_ignore(s); + if (!s) return s; + } + else + { + // some control group + s = parse_doctype_group(s, endch, false); + if (!s) return s; + } + } + else if (s[0] == '<' || s[0] == '"' || s[0] == '\'') + { + // unknown tag (forbidden), or some primitive group + s = parse_doctype_primitive(s); + if (!s) return s; + } + else if (*s == '>') + { + s++; + + return s; + } + else s++; + } + + if (!toplevel || endch != '>') PUGI__THROW_ERROR(status_bad_doctype, s); + + return s; + } + + char_t* parse_exclamation(char_t* s, xml_node_struct* cursor, unsigned int optmsk, char_t endch) + { + // parse node contents, starting with exclamation mark + ++s; + + if (*s == '-') // 'value = s; // Save the offset. + } + + if (PUGI__OPTSET(parse_eol) && PUGI__OPTSET(parse_comments)) + { + s = strconv_comment(s, endch); + + if (!s) PUGI__THROW_ERROR(status_bad_comment, cursor->value); + } + else + { + // Scan for terminating '-->'. + PUGI__SCANFOR(s[0] == '-' && s[1] == '-' && ENDSWITH(s[2], '>')); + PUGI__CHECK_ERROR(status_bad_comment, s); + + if (PUGI__OPTSET(parse_comments)) + *s = 0; // Zero-terminate this segment at the first terminating '-'. + + s += (s[2] == '>' ? 3 : 2); // Step over the '\0->'. + } + } + else PUGI__THROW_ERROR(status_bad_comment, s); + } + else if (*s == '[') + { + // 'value = s; // Save the offset. 
+ + if (PUGI__OPTSET(parse_eol)) + { + s = strconv_cdata(s, endch); + + if (!s) PUGI__THROW_ERROR(status_bad_cdata, cursor->value); + } + else + { + // Scan for terminating ']]>'. + PUGI__SCANFOR(s[0] == ']' && s[1] == ']' && ENDSWITH(s[2], '>')); + PUGI__CHECK_ERROR(status_bad_cdata, s); + + *s++ = 0; // Zero-terminate this segment. + } + } + else // Flagged for discard, but we still have to scan for the terminator. + { + // Scan for terminating ']]>'. + PUGI__SCANFOR(s[0] == ']' && s[1] == ']' && ENDSWITH(s[2], '>')); + PUGI__CHECK_ERROR(status_bad_cdata, s); + + ++s; + } + + s += (s[1] == '>' ? 2 : 1); // Step over the last ']>'. + } + else PUGI__THROW_ERROR(status_bad_cdata, s); + } + else if (s[0] == 'D' && s[1] == 'O' && s[2] == 'C' && s[3] == 'T' && s[4] == 'Y' && s[5] == 'P' && ENDSWITH(s[6], 'E')) + { + s -= 2; + + if (cursor->parent) PUGI__THROW_ERROR(status_bad_doctype, s); + + char_t* mark = s + 9; + + s = parse_doctype_group(s, endch, true); + if (!s) return s; + + if (PUGI__OPTSET(parse_doctype)) + { + while (PUGI__IS_CHARTYPE(*mark, ct_space)) ++mark; + + PUGI__PUSHNODE(node_doctype); + + cursor->value = mark; + + assert((s[0] == 0 && endch == '>') || s[-1] == '>'); + s[*s == 0 ? 
0 : -1] = 0; + + PUGI__POPNODE(); + } + } + else if (*s == 0 && endch == '-') PUGI__THROW_ERROR(status_bad_comment, s); + else if (*s == 0 && endch == '[') PUGI__THROW_ERROR(status_bad_cdata, s); + else PUGI__THROW_ERROR(status_unrecognized_tag, s); + + return s; + } + + char_t* parse_question(char_t* s, xml_node_struct*& ref_cursor, unsigned int optmsk, char_t endch) + { + // load into registers + xml_node_struct* cursor = ref_cursor; + char_t ch = 0; + + // parse node contents, starting with question mark + ++s; + + // read PI target + char_t* target = s; + + if (!PUGI__IS_CHARTYPE(*s, ct_start_symbol)) PUGI__THROW_ERROR(status_bad_pi, s); + + PUGI__SCANWHILE(PUGI__IS_CHARTYPE(*s, ct_symbol)); + PUGI__CHECK_ERROR(status_bad_pi, s); + + // determine node type; stricmp / strcasecmp is not portable + bool declaration = (target[0] | ' ') == 'x' && (target[1] | ' ') == 'm' && (target[2] | ' ') == 'l' && target + 3 == s; + + if (declaration ? PUGI__OPTSET(parse_declaration) : PUGI__OPTSET(parse_pi)) + { + if (declaration) + { + // disallow non top-level declarations + if (cursor->parent) PUGI__THROW_ERROR(status_bad_pi, s); + + PUGI__PUSHNODE(node_declaration); + } + else + { + PUGI__PUSHNODE(node_pi); + } + + cursor->name = target; + + PUGI__ENDSEG(); + + // parse value/attributes + if (ch == '?') + { + // empty node + if (!ENDSWITH(*s, '>')) PUGI__THROW_ERROR(status_bad_pi, s); + s += (*s == '>'); + + PUGI__POPNODE(); + } + else if (PUGI__IS_CHARTYPE(ch, ct_space)) + { + PUGI__SKIPWS(); + + // scan for tag end + char_t* value = s; + + PUGI__SCANFOR(s[0] == '?' && ENDSWITH(s[1], '>')); + PUGI__CHECK_ERROR(status_bad_pi, s); + + if (declaration) + { + // replace ending ? 
with / so that 'element' terminates properly + *s = '/'; + + // we exit from this function with cursor at node_declaration, which is a signal to parse() to go to LOC_ATTRIBUTES + s = value; + } + else + { + // store value and step over > + cursor->value = value; + PUGI__POPNODE(); + + PUGI__ENDSEG(); + + s += (*s == '>'); + } + } + else PUGI__THROW_ERROR(status_bad_pi, s); + } + else + { + // scan for tag end + PUGI__SCANFOR(s[0] == '?' && ENDSWITH(s[1], '>')); + PUGI__CHECK_ERROR(status_bad_pi, s); + + s += (s[1] == '>' ? 2 : 1); + } + + // store from registers + ref_cursor = cursor; + + return s; + } + + char_t* parse(char_t* s, xml_node_struct* xmldoc, unsigned int optmsk, char_t endch) + { + strconv_attribute_t strconv_attribute = get_strconv_attribute(optmsk); + strconv_pcdata_t strconv_pcdata = get_strconv_pcdata(optmsk); + + char_t ch = 0; + xml_node_struct* cursor = xmldoc; + char_t* mark = s; + + while (*s != 0) + { + if (*s == '<') + { + ++s; + + LOC_TAG: + if (PUGI__IS_CHARTYPE(*s, ct_start_symbol)) // '<#...' + { + PUGI__PUSHNODE(node_element); // Append a new node to the tree. + + cursor->name = s; + + PUGI__SCANWHILE(PUGI__IS_CHARTYPE(*s, ct_symbol)); // Scan for a terminator. + PUGI__ENDSEG(); // Save char in 'ch', terminate & step over. + + if (ch == '>') + { + // end of tag + } + else if (PUGI__IS_CHARTYPE(ch, ct_space)) + { + LOC_ATTRIBUTES: + while (true) + { + PUGI__SKIPWS(); // Eat any whitespace. + + if (PUGI__IS_CHARTYPE(*s, ct_start_symbol)) // <... #... + { + xml_attribute_struct* a = append_attribute_ll(cursor, alloc); // Make space for this attribute. + if (!a) PUGI__THROW_ERROR(status_out_of_memory, s); + + a->name = s; // Save the offset. + + PUGI__SCANWHILE(PUGI__IS_CHARTYPE(*s, ct_symbol)); // Scan for a terminator. + PUGI__CHECK_ERROR(status_bad_attribute, s); //$ redundant, left for performance + + PUGI__ENDSEG(); // Save char in 'ch', terminate & step over. 
+ PUGI__CHECK_ERROR(status_bad_attribute, s); //$ redundant, left for performance + + if (PUGI__IS_CHARTYPE(ch, ct_space)) + { + PUGI__SKIPWS(); // Eat any whitespace. + PUGI__CHECK_ERROR(status_bad_attribute, s); //$ redundant, left for performance + + ch = *s; + ++s; + } + + if (ch == '=') // '<... #=...' + { + PUGI__SKIPWS(); // Eat any whitespace. + + if (*s == '"' || *s == '\'') // '<... #="...' + { + ch = *s; // Save quote char to avoid breaking on "''" -or- '""'. + ++s; // Step over the quote. + a->value = s; // Save the offset. + + s = strconv_attribute(s, ch); + + if (!s) PUGI__THROW_ERROR(status_bad_attribute, a->value); + + // After this line the loop continues from the start; + // Whitespaces, / and > are ok, symbols and EOF are wrong, + // everything else will be detected + if (PUGI__IS_CHARTYPE(*s, ct_start_symbol)) PUGI__THROW_ERROR(status_bad_attribute, s); + } + else PUGI__THROW_ERROR(status_bad_attribute, s); + } + else PUGI__THROW_ERROR(status_bad_attribute, s); + } + else if (*s == '/') + { + ++s; + + if (*s == '>') + { + PUGI__POPNODE(); + s++; + break; + } + else if (*s == 0 && endch == '>') + { + PUGI__POPNODE(); + break; + } + else PUGI__THROW_ERROR(status_bad_start_element, s); + } + else if (*s == '>') + { + ++s; + + break; + } + else if (*s == 0 && endch == '>') + { + break; + } + else PUGI__THROW_ERROR(status_bad_start_element, s); + } + + // !!! + } + else if (ch == '/') // '<#.../' + { + if (!ENDSWITH(*s, '>')) PUGI__THROW_ERROR(status_bad_start_element, s); + + PUGI__POPNODE(); // Pop. 
+ + s += (*s == '>'); + } + else if (ch == 0) + { + // we stepped over null terminator, backtrack & handle closing tag + --s; + + if (endch != '>') PUGI__THROW_ERROR(status_bad_start_element, s); + } + else PUGI__THROW_ERROR(status_bad_start_element, s); + } + else if (*s == '/') + { + ++s; + + char_t* name = cursor->name; + if (!name) PUGI__THROW_ERROR(status_end_element_mismatch, s); + + while (PUGI__IS_CHARTYPE(*s, ct_symbol)) + { + if (*s++ != *name++) PUGI__THROW_ERROR(status_end_element_mismatch, s); + } + + if (*name) + { + if (*s == 0 && name[0] == endch && name[1] == 0) PUGI__THROW_ERROR(status_bad_end_element, s); + else PUGI__THROW_ERROR(status_end_element_mismatch, s); + } + + PUGI__POPNODE(); // Pop. + + PUGI__SKIPWS(); + + if (*s == 0) + { + if (endch != '>') PUGI__THROW_ERROR(status_bad_end_element, s); + } + else + { + if (*s != '>') PUGI__THROW_ERROR(status_bad_end_element, s); + ++s; + } + } + else if (*s == '?') // 'header & xml_memory_page_type_mask) + 1 == node_declaration) goto LOC_ATTRIBUTES; + } + else if (*s == '!') // 'first_child) continue; + } + } + + s = mark; + + if (cursor->parent) + { + PUGI__PUSHNODE(node_pcdata); // Append a new node on the tree. + cursor->value = s; // Save the offset. + + s = strconv_pcdata(s); + + PUGI__POPNODE(); // Pop since this is a standalone. 
+ + if (!*s) break; + } + else + { + PUGI__SCANFOR(*s == '<'); // '...<' + if (!*s) break; + + ++s; + } + + // We're after '<' + goto LOC_TAG; + } + } + + // check that last tag is closed + if (cursor != xmldoc) PUGI__THROW_ERROR(status_end_element_mismatch, s); + + return s; + } + + static xml_parse_result parse(char_t* buffer, size_t length, xml_node_struct* root, unsigned int optmsk) + { + xml_document_struct* xmldoc = static_cast(root); + + // store buffer for offset_debug + xmldoc->buffer = buffer; + + // early-out for empty documents + if (length == 0) return make_parse_result(status_ok); + + // create parser on stack + xml_parser parser(*xmldoc); + + // save last character and make buffer zero-terminated (speeds up parsing) + char_t endch = buffer[length - 1]; + buffer[length - 1] = 0; + + // perform actual parsing + parser.parse(buffer, xmldoc, optmsk, endch); + + xml_parse_result result = make_parse_result(parser.error_status, parser.error_offset ? parser.error_offset - buffer : 0); + assert(result.offset >= 0 && static_cast(result.offset) <= length); + + // update allocator state + *static_cast(xmldoc) = parser.alloc; + + // since we removed last character, we have to handle the only possible false positive + if (result && endch == '<') + { + // there's no possible well-formed document with < at the end + return make_parse_result(status_unrecognized_tag, length); + } + + return result; + } + }; + + // Output facilities + PUGI__FN xml_encoding get_write_native_encoding() + { + #ifdef PUGIXML_WCHAR_MODE + return get_wchar_encoding(); + #else + return encoding_utf8; + #endif + } + + PUGI__FN xml_encoding get_write_encoding(xml_encoding encoding) + { + // replace wchar encoding with utf implementation + if (encoding == encoding_wchar) return get_wchar_encoding(); + + // replace utf16 encoding with utf16 with specific endianness + if (encoding == encoding_utf16) return is_little_endian() ? 
encoding_utf16_le : encoding_utf16_be; + + // replace utf32 encoding with utf32 with specific endianness + if (encoding == encoding_utf32) return is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + // only do autodetection if no explicit encoding is requested + if (encoding != encoding_auto) return encoding; + + // assume utf8 encoding + return encoding_utf8; + } + +#ifdef PUGIXML_WCHAR_MODE + PUGI__FN size_t get_valid_length(const char_t* data, size_t length) + { + assert(length > 0); + + // discard last character if it's the lead of a surrogate pair + return (sizeof(wchar_t) == 2 && static_cast(static_cast(data[length - 1]) - 0xD800) < 0x400) ? length - 1 : length; + } + + PUGI__FN size_t convert_buffer(char_t* r_char, uint8_t* r_u8, uint16_t* r_u16, uint32_t* r_u32, const char_t* data, size_t length, xml_encoding encoding) + { + // only endian-swapping is required + if (need_endian_swap_utf(encoding, get_wchar_encoding())) + { + convert_wchar_endian_swap(r_char, data, length); + + return length * sizeof(char_t); + } + + // convert to utf8 + if (encoding == encoding_utf8) + { + uint8_t* dest = r_u8; + uint8_t* end = utf_decoder::decode_wchar_block(data, length, dest); + + return static_cast(end - dest); + } + + // convert to utf16 + if (encoding == encoding_utf16_be || encoding == encoding_utf16_le) + { + uint16_t* dest = r_u16; + + // convert to native utf16 + uint16_t* end = utf_decoder::decode_wchar_block(data, length, dest); + + // swap if necessary + xml_encoding native_encoding = is_little_endian() ? 
encoding_utf16_le : encoding_utf16_be; + + if (native_encoding != encoding) convert_utf_endian_swap(dest, dest, static_cast(end - dest)); + + return static_cast(end - dest) * sizeof(uint16_t); + } + + // convert to utf32 + if (encoding == encoding_utf32_be || encoding == encoding_utf32_le) + { + uint32_t* dest = r_u32; + + // convert to native utf32 + uint32_t* end = utf_decoder::decode_wchar_block(data, length, dest); + + // swap if necessary + xml_encoding native_encoding = is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + if (native_encoding != encoding) convert_utf_endian_swap(dest, dest, static_cast(end - dest)); + + return static_cast(end - dest) * sizeof(uint32_t); + } + + // convert to latin1 + if (encoding == encoding_latin1) + { + uint8_t* dest = r_u8; + uint8_t* end = utf_decoder::decode_wchar_block(data, length, dest); + + return static_cast(end - dest); + } + + assert(!"Invalid encoding"); + return 0; + } +#else + PUGI__FN size_t get_valid_length(const char_t* data, size_t length) + { + assert(length > 4); + + for (size_t i = 1; i <= 4; ++i) + { + uint8_t ch = static_cast(data[length - i]); + + // either a standalone character or a leading one + if ((ch & 0xc0) != 0x80) return length - i; + } + + // there are four non-leading characters at the end, sequence tail is broken so might as well process the whole chunk + return length; + } + + PUGI__FN size_t convert_buffer(char_t* /* r_char */, uint8_t* r_u8, uint16_t* r_u16, uint32_t* r_u32, const char_t* data, size_t length, xml_encoding encoding) + { + if (encoding == encoding_utf16_be || encoding == encoding_utf16_le) + { + uint16_t* dest = r_u16; + + // convert to native utf16 + uint16_t* end = utf_decoder::decode_utf8_block(reinterpret_cast(data), length, dest); + + // swap if necessary + xml_encoding native_encoding = is_little_endian() ? 
encoding_utf16_le : encoding_utf16_be; + + if (native_encoding != encoding) convert_utf_endian_swap(dest, dest, static_cast(end - dest)); + + return static_cast(end - dest) * sizeof(uint16_t); + } + + if (encoding == encoding_utf32_be || encoding == encoding_utf32_le) + { + uint32_t* dest = r_u32; + + // convert to native utf32 + uint32_t* end = utf_decoder::decode_utf8_block(reinterpret_cast(data), length, dest); + + // swap if necessary + xml_encoding native_encoding = is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + if (native_encoding != encoding) convert_utf_endian_swap(dest, dest, static_cast(end - dest)); + + return static_cast(end - dest) * sizeof(uint32_t); + } + + if (encoding == encoding_latin1) + { + uint8_t* dest = r_u8; + uint8_t* end = utf_decoder::decode_utf8_block(reinterpret_cast(data), length, dest); + + return static_cast(end - dest); + } + + assert(!"Invalid encoding"); + return 0; + } +#endif + + class xml_buffered_writer + { + xml_buffered_writer(const xml_buffered_writer&); + xml_buffered_writer& operator=(const xml_buffered_writer&); + + public: + xml_buffered_writer(xml_writer& writer_, xml_encoding user_encoding): writer(writer_), bufsize(0), encoding(get_write_encoding(user_encoding)) + { + PUGI__STATIC_ASSERT(bufcapacity >= 8); + } + + ~xml_buffered_writer() + { + flush(); + } + + void flush() + { + flush(buffer, bufsize); + bufsize = 0; + } + + void flush(const char_t* data, size_t size) + { + if (size == 0) return; + + // fast path, just write data + if (encoding == get_write_native_encoding()) + writer.write(data, size * sizeof(char_t)); + else + { + // convert chunk + size_t result = convert_buffer(scratch.data_char, scratch.data_u8, scratch.data_u16, scratch.data_u32, data, size, encoding); + assert(result <= sizeof(scratch)); + + // write data + writer.write(scratch.data_u8, result); + } + } + + void write(const char_t* data, size_t length) + { + if (bufsize + length > bufcapacity) + { + // flush the remaining 
buffer contents + flush(); + + // handle large chunks + if (length > bufcapacity) + { + if (encoding == get_write_native_encoding()) + { + // fast path, can just write data chunk + writer.write(data, length * sizeof(char_t)); + return; + } + + // need to convert in suitable chunks + while (length > bufcapacity) + { + // get chunk size by selecting such number of characters that are guaranteed to fit into scratch buffer + // and form a complete codepoint sequence (i.e. discard start of last codepoint if necessary) + size_t chunk_size = get_valid_length(data, bufcapacity); + + // convert chunk and write + flush(data, chunk_size); + + // iterate + data += chunk_size; + length -= chunk_size; + } + + // small tail is copied below + bufsize = 0; + } + } + + memcpy(buffer + bufsize, data, length * sizeof(char_t)); + bufsize += length; + } + + void write(const char_t* data) + { + write(data, strlength(data)); + } + + void write(char_t d0) + { + if (bufsize + 1 > bufcapacity) flush(); + + buffer[bufsize + 0] = d0; + bufsize += 1; + } + + void write(char_t d0, char_t d1) + { + if (bufsize + 2 > bufcapacity) flush(); + + buffer[bufsize + 0] = d0; + buffer[bufsize + 1] = d1; + bufsize += 2; + } + + void write(char_t d0, char_t d1, char_t d2) + { + if (bufsize + 3 > bufcapacity) flush(); + + buffer[bufsize + 0] = d0; + buffer[bufsize + 1] = d1; + buffer[bufsize + 2] = d2; + bufsize += 3; + } + + void write(char_t d0, char_t d1, char_t d2, char_t d3) + { + if (bufsize + 4 > bufcapacity) flush(); + + buffer[bufsize + 0] = d0; + buffer[bufsize + 1] = d1; + buffer[bufsize + 2] = d2; + buffer[bufsize + 3] = d3; + bufsize += 4; + } + + void write(char_t d0, char_t d1, char_t d2, char_t d3, char_t d4) + { + if (bufsize + 5 > bufcapacity) flush(); + + buffer[bufsize + 0] = d0; + buffer[bufsize + 1] = d1; + buffer[bufsize + 2] = d2; + buffer[bufsize + 3] = d3; + buffer[bufsize + 4] = d4; + bufsize += 5; + } + + void write(char_t d0, char_t d1, char_t d2, char_t d3, char_t d4, char_t d5) 
+ { + if (bufsize + 6 > bufcapacity) flush(); + + buffer[bufsize + 0] = d0; + buffer[bufsize + 1] = d1; + buffer[bufsize + 2] = d2; + buffer[bufsize + 3] = d3; + buffer[bufsize + 4] = d4; + buffer[bufsize + 5] = d5; + bufsize += 6; + } + + // utf8 maximum expansion: x4 (-> utf32) + // utf16 maximum expansion: x2 (-> utf32) + // utf32 maximum expansion: x1 + enum + { + bufcapacitybytes = + #ifdef PUGIXML_MEMORY_OUTPUT_STACK + PUGIXML_MEMORY_OUTPUT_STACK + #else + 10240 + #endif + , + bufcapacity = bufcapacitybytes / (sizeof(char_t) + 4) + }; + + char_t buffer[bufcapacity]; + + union + { + uint8_t data_u8[4 * bufcapacity]; + uint16_t data_u16[2 * bufcapacity]; + uint32_t data_u32[bufcapacity]; + char_t data_char[bufcapacity]; + } scratch; + + xml_writer& writer; + size_t bufsize; + xml_encoding encoding; + }; + + PUGI__FN void text_output_escaped(xml_buffered_writer& writer, const char_t* s, chartypex_t type) + { + while (*s) + { + const char_t* prev = s; + + // While *s is a usual symbol + while (!PUGI__IS_CHARTYPEX(*s, type)) ++s; + + writer.write(prev, static_cast(s - prev)); + + switch (*s) + { + case 0: break; + case '&': + writer.write('&', 'a', 'm', 'p', ';'); + ++s; + break; + case '<': + writer.write('&', 'l', 't', ';'); + ++s; + break; + case '>': + writer.write('&', 'g', 't', ';'); + ++s; + break; + case '"': + writer.write('&', 'q', 'u', 'o', 't', ';'); + ++s; + break; + default: // s is not a usual symbol + { + unsigned int ch = static_cast(*s++); + assert(ch < 32); + + writer.write('&', '#', static_cast((ch / 10) + '0'), static_cast((ch % 10) + '0'), ';'); + } + } + } + } + + PUGI__FN void text_output(xml_buffered_writer& writer, const char_t* s, chartypex_t type, unsigned int flags) + { + if (flags & format_no_escapes) + writer.write(s); + else + text_output_escaped(writer, s, type); + } + + PUGI__FN void text_output_cdata(xml_buffered_writer& writer, const char_t* s) + { + do + { + writer.write('<', '!', '[', 'C', 'D'); + writer.write('A', 'T', 'A', 
'['); + + const char_t* prev = s; + + // look for ]]> sequence - we can't output it as is since it terminates CDATA + while (*s && !(s[0] == ']' && s[1] == ']' && s[2] == '>')) ++s; + + // skip ]] if we stopped at ]]>, > will go to the next CDATA section + if (*s) s += 2; + + writer.write(prev, static_cast(s - prev)); + + writer.write(']', ']', '>'); + } + while (*s); + } + + PUGI__FN void node_output_attributes(xml_buffered_writer& writer, const xml_node& node, unsigned int flags) + { + const char_t* default_name = PUGIXML_TEXT(":anonymous"); + + for (xml_attribute a = node.first_attribute(); a; a = a.next_attribute()) + { + writer.write(' '); + writer.write(a.name()[0] ? a.name() : default_name); + writer.write('=', '"'); + + text_output(writer, a.value(), ctx_special_attr, flags); + + writer.write('"'); + } + } + + PUGI__FN void node_output(xml_buffered_writer& writer, const xml_node& node, const char_t* indent, unsigned int flags, unsigned int depth) + { + const char_t* default_name = PUGIXML_TEXT(":anonymous"); + + if ((flags & format_indent) != 0 && (flags & format_raw) == 0) + for (unsigned int i = 0; i < depth; ++i) writer.write(indent); + + switch (node.type()) + { + case node_document: + { + for (xml_node n = node.first_child(); n; n = n.next_sibling()) + node_output(writer, n, indent, flags, depth); + break; + } + + case node_element: + { + const char_t* name = node.name()[0] ? 
node.name() : default_name; + + writer.write('<'); + writer.write(name); + + node_output_attributes(writer, node, flags); + + if (flags & format_raw) + { + if (!node.first_child()) + writer.write(' ', '/', '>'); + else + { + writer.write('>'); + + for (xml_node n = node.first_child(); n; n = n.next_sibling()) + node_output(writer, n, indent, flags, depth + 1); + + writer.write('<', '/'); + writer.write(name); + writer.write('>'); + } + } + else if (!node.first_child()) + writer.write(' ', '/', '>', '\n'); + else if (node.first_child() == node.last_child() && (node.first_child().type() == node_pcdata || node.first_child().type() == node_cdata)) + { + writer.write('>'); + + if (node.first_child().type() == node_pcdata) + text_output(writer, node.first_child().value(), ctx_special_pcdata, flags); + else + text_output_cdata(writer, node.first_child().value()); + + writer.write('<', '/'); + writer.write(name); + writer.write('>', '\n'); + } + else + { + writer.write('>', '\n'); + + for (xml_node n = node.first_child(); n; n = n.next_sibling()) + node_output(writer, n, indent, flags, depth + 1); + + if ((flags & format_indent) != 0 && (flags & format_raw) == 0) + for (unsigned int i = 0; i < depth; ++i) writer.write(indent); + + writer.write('<', '/'); + writer.write(name); + writer.write('>', '\n'); + } + + break; + } + + case node_pcdata: + text_output(writer, node.value(), ctx_special_pcdata, flags); + if ((flags & format_raw) == 0) writer.write('\n'); + break; + + case node_cdata: + text_output_cdata(writer, node.value()); + if ((flags & format_raw) == 0) writer.write('\n'); + break; + + case node_comment: + writer.write('<', '!', '-', '-'); + writer.write(node.value()); + writer.write('-', '-', '>'); + if ((flags & format_raw) == 0) writer.write('\n'); + break; + + case node_pi: + case node_declaration: + writer.write('<', '?'); + writer.write(node.name()[0] ? 
node.name() : default_name); + + if (node.type() == node_declaration) + { + node_output_attributes(writer, node, flags); + } + else if (node.value()[0]) + { + writer.write(' '); + writer.write(node.value()); + } + + writer.write('?', '>'); + if ((flags & format_raw) == 0) writer.write('\n'); + break; + + case node_doctype: + writer.write('<', '!', 'D', 'O', 'C'); + writer.write('T', 'Y', 'P', 'E'); + + if (node.value()[0]) + { + writer.write(' '); + writer.write(node.value()); + } + + writer.write('>'); + if ((flags & format_raw) == 0) writer.write('\n'); + break; + + default: + assert(!"Invalid node type"); + } + } + + inline bool has_declaration(const xml_node& node) + { + for (xml_node child = node.first_child(); child; child = child.next_sibling()) + { + xml_node_type type = child.type(); + + if (type == node_declaration) return true; + if (type == node_element) return false; + } + + return false; + } + + inline bool allow_insert_child(xml_node_type parent, xml_node_type child) + { + if (parent != node_document && parent != node_element) return false; + if (child == node_document || child == node_null) return false; + if (parent != node_document && (child == node_declaration || child == node_doctype)) return false; + + return true; + } + + PUGI__FN void recursive_copy_skip(xml_node& dest, const xml_node& source, const xml_node& skip) + { + assert(dest.type() == source.type()); + + switch (source.type()) + { + case node_element: + { + dest.set_name(source.name()); + + for (xml_attribute a = source.first_attribute(); a; a = a.next_attribute()) + dest.append_attribute(a.name()).set_value(a.value()); + + for (xml_node c = source.first_child(); c; c = c.next_sibling()) + { + if (c == skip) continue; + + xml_node cc = dest.append_child(c.type()); + assert(cc); + + recursive_copy_skip(cc, c, skip); + } + + break; + } + + case node_pcdata: + case node_cdata: + case node_comment: + case node_doctype: + dest.set_value(source.value()); + break; + + case node_pi: + 
dest.set_name(source.name()); + dest.set_value(source.value()); + break; + + case node_declaration: + { + dest.set_name(source.name()); + + for (xml_attribute a = source.first_attribute(); a; a = a.next_attribute()) + dest.append_attribute(a.name()).set_value(a.value()); + + break; + } + + default: + assert(!"Invalid node type"); + } + } + + inline bool is_text_node(xml_node_struct* node) + { + xml_node_type type = static_cast((node->header & impl::xml_memory_page_type_mask) + 1); + + return type == node_pcdata || type == node_cdata; + } + + // get value with conversion functions + PUGI__FN int get_value_int(const char_t* value, int def) + { + if (!value) return def; + + #ifdef PUGIXML_WCHAR_MODE + return static_cast(wcstol(value, 0, 10)); + #else + return static_cast(strtol(value, 0, 10)); + #endif + } + + PUGI__FN unsigned int get_value_uint(const char_t* value, unsigned int def) + { + if (!value) return def; + + #ifdef PUGIXML_WCHAR_MODE + return static_cast(wcstoul(value, 0, 10)); + #else + return static_cast(strtoul(value, 0, 10)); + #endif + } + + PUGI__FN double get_value_double(const char_t* value, double def) + { + if (!value) return def; + + #ifdef PUGIXML_WCHAR_MODE + return wcstod(value, 0); + #else + return strtod(value, 0); + #endif + } + + PUGI__FN float get_value_float(const char_t* value, float def) + { + if (!value) return def; + + #ifdef PUGIXML_WCHAR_MODE + return static_cast(wcstod(value, 0)); + #else + return static_cast(strtod(value, 0)); + #endif + } + + PUGI__FN bool get_value_bool(const char_t* value, bool def) + { + if (!value) return def; + + // only look at first char + char_t first = *value; + + // 1*, t* (true), T* (True), y* (yes), Y* (YES) + return (first == '1' || first == 't' || first == 'T' || first == 'y' || first == 'Y'); + } + + // set value with conversion functions + PUGI__FN bool set_value_buffer(char_t*& dest, uintptr_t& header, uintptr_t header_mask, char (&buf)[128]) + { + #ifdef PUGIXML_WCHAR_MODE + char_t wbuf[128]; + 
impl::widen_ascii(wbuf, buf); + + return strcpy_insitu(dest, header, header_mask, wbuf); + #else + return strcpy_insitu(dest, header, header_mask, buf); + #endif + } + + PUGI__FN bool set_value_convert(char_t*& dest, uintptr_t& header, uintptr_t header_mask, int value) + { + char buf[128]; + sprintf(buf, "%d", value); + + return set_value_buffer(dest, header, header_mask, buf); + } + + PUGI__FN bool set_value_convert(char_t*& dest, uintptr_t& header, uintptr_t header_mask, unsigned int value) + { + char buf[128]; + sprintf(buf, "%u", value); + + return set_value_buffer(dest, header, header_mask, buf); + } + + PUGI__FN bool set_value_convert(char_t*& dest, uintptr_t& header, uintptr_t header_mask, double value) + { + char buf[128]; + sprintf(buf, "%g", value); + + return set_value_buffer(dest, header, header_mask, buf); + } + + PUGI__FN bool set_value_convert(char_t*& dest, uintptr_t& header, uintptr_t header_mask, bool value) + { + return strcpy_insitu(dest, header, header_mask, value ? PUGIXML_TEXT("true") : PUGIXML_TEXT("false")); + } + + // we need to get length of entire file to load it in memory; the only (relatively) sane way to do it is via seek/tell trick + PUGI__FN xml_parse_status get_file_size(FILE* file, size_t& out_result) + { + #if defined(PUGI__MSVC_CRT_VERSION) && PUGI__MSVC_CRT_VERSION >= 1400 && !defined(_WIN32_WCE) + // there are 64-bit versions of fseek/ftell, let's use them + typedef __int64 length_type; + + _fseeki64(file, 0, SEEK_END); + length_type length = _ftelli64(file); + _fseeki64(file, 0, SEEK_SET); + #elif defined(__MINGW32__) && !defined(__NO_MINGW_LFS) && !defined(__STRICT_ANSI__) + // there are 64-bit versions of fseek/ftell, let's use them + typedef off64_t length_type; + + fseeko64(file, 0, SEEK_END); + length_type length = ftello64(file); + fseeko64(file, 0, SEEK_SET); + #else + // if this is a 32-bit OS, long is enough; if this is a unix system, long is 64-bit, which is enough; otherwise we can't do anything anyway. 
+ typedef long length_type; + + fseek(file, 0, SEEK_END); + length_type length = ftell(file); + fseek(file, 0, SEEK_SET); + #endif + + // check for I/O errors + if (length < 0) return status_io_error; + + // check for overflow + size_t result = static_cast(length); + + if (static_cast(result) != length) return status_out_of_memory; + + // finalize + out_result = result; + + return status_ok; + } + + PUGI__FN xml_parse_result load_file_impl(xml_document& doc, FILE* file, unsigned int options, xml_encoding encoding) + { + if (!file) return make_parse_result(status_file_not_found); + + // get file size (can result in I/O errors) + size_t size = 0; + xml_parse_status size_status = get_file_size(file, size); + + if (size_status != status_ok) + { + fclose(file); + return make_parse_result(size_status); + } + + // allocate buffer for the whole file + char* contents = static_cast(xml_memory::allocate(size > 0 ? size : 1)); + + if (!contents) + { + fclose(file); + return make_parse_result(status_out_of_memory); + } + + // read file in memory + size_t read_size = fread(contents, 1, size, file); + fclose(file); + + if (read_size != size) + { + xml_memory::deallocate(contents); + return make_parse_result(status_io_error); + } + + return doc.load_buffer_inplace_own(contents, size, options, encoding); + } + +#ifndef PUGIXML_NO_STL + template struct xml_stream_chunk + { + static xml_stream_chunk* create() + { + void* memory = xml_memory::allocate(sizeof(xml_stream_chunk)); + + return new (memory) xml_stream_chunk(); + } + + static void destroy(void* ptr) + { + xml_stream_chunk* chunk = static_cast(ptr); + + // free chunk chain + while (chunk) + { + xml_stream_chunk* next = chunk->next; + xml_memory::deallocate(chunk); + chunk = next; + } + } + + xml_stream_chunk(): next(0), size(0) + { + } + + xml_stream_chunk* next; + size_t size; + + T data[xml_memory_page_size / sizeof(T)]; + }; + + template PUGI__FN xml_parse_status load_stream_data_noseek(std::basic_istream& stream, void** 
out_buffer, size_t* out_size) + { + buffer_holder chunks(0, xml_stream_chunk::destroy); + + // read file to a chunk list + size_t total = 0; + xml_stream_chunk* last = 0; + + while (!stream.eof()) + { + // allocate new chunk + xml_stream_chunk* chunk = xml_stream_chunk::create(); + if (!chunk) return status_out_of_memory; + + // append chunk to list + if (last) last = last->next = chunk; + else chunks.data = last = chunk; + + // read data to chunk + stream.read(chunk->data, static_cast(sizeof(chunk->data) / sizeof(T))); + chunk->size = static_cast(stream.gcount()) * sizeof(T); + + // read may set failbit | eofbit in case gcount() is less than read length, so check for other I/O errors + if (stream.bad() || (!stream.eof() && stream.fail())) return status_io_error; + + // guard against huge files (chunk size is small enough to make this overflow check work) + if (total + chunk->size < total) return status_out_of_memory; + total += chunk->size; + } + + // copy chunk list to a contiguous buffer + char* buffer = static_cast(xml_memory::allocate(total)); + if (!buffer) return status_out_of_memory; + + char* write = buffer; + + for (xml_stream_chunk* chunk = static_cast*>(chunks.data); chunk; chunk = chunk->next) + { + assert(write + chunk->size <= buffer + total); + memcpy(write, chunk->data, chunk->size); + write += chunk->size; + } + + assert(write == buffer + total); + + // return buffer + *out_buffer = buffer; + *out_size = total; + + return status_ok; + } + + template PUGI__FN xml_parse_status load_stream_data_seek(std::basic_istream& stream, void** out_buffer, size_t* out_size) + { + // get length of remaining data in stream + typename std::basic_istream::pos_type pos = stream.tellg(); + stream.seekg(0, std::ios::end); + std::streamoff length = stream.tellg() - pos; + stream.seekg(pos); + + if (stream.fail() || pos < 0) return status_io_error; + + // guard against huge files + size_t read_length = static_cast(length); + + if (static_cast(read_length) != length || 
length < 0) return status_out_of_memory; + + // read stream data into memory (guard against stream exceptions with buffer holder) + buffer_holder buffer(xml_memory::allocate((read_length > 0 ? read_length : 1) * sizeof(T)), xml_memory::deallocate); + if (!buffer.data) return status_out_of_memory; + + stream.read(static_cast(buffer.data), static_cast(read_length)); + + // read may set failbit | eofbit in case gcount() is less than read_length (i.e. line ending conversion), so check for other I/O errors + if (stream.bad() || (!stream.eof() && stream.fail())) return status_io_error; + + // return buffer + size_t actual_length = static_cast(stream.gcount()); + assert(actual_length <= read_length); + + *out_buffer = buffer.release(); + *out_size = actual_length * sizeof(T); + + return status_ok; + } + + template PUGI__FN xml_parse_result load_stream_impl(xml_document& doc, std::basic_istream& stream, unsigned int options, xml_encoding encoding) + { + void* buffer = 0; + size_t size = 0; + + // load stream to memory (using seek-based implementation if possible, since it's faster and takes less memory) + xml_parse_status status = (stream.tellg() < 0) ? 
load_stream_data_noseek(stream, &buffer, &size) : load_stream_data_seek(stream, &buffer, &size); + if (status != status_ok) return make_parse_result(status); + + return doc.load_buffer_inplace_own(buffer, size, options, encoding); + } +#endif + +#if defined(PUGI__MSVC_CRT_VERSION) || defined(__BORLANDC__) || (defined(__MINGW32__) && !defined(__STRICT_ANSI__)) + PUGI__FN FILE* open_file_wide(const wchar_t* path, const wchar_t* mode) + { + return _wfopen(path, mode); + } +#else + PUGI__FN char* convert_path_heap(const wchar_t* str) + { + assert(str); + + // first pass: get length in utf8 characters + size_t length = wcslen(str); + size_t size = as_utf8_begin(str, length); + + // allocate resulting string + char* result = static_cast(xml_memory::allocate(size + 1)); + if (!result) return 0; + + // second pass: convert to utf8 + as_utf8_end(result, size, str, length); + + return result; + } + + PUGI__FN FILE* open_file_wide(const wchar_t* path, const wchar_t* mode) + { + // there is no standard function to open wide paths, so our best bet is to try utf8 path + char* path_utf8 = convert_path_heap(path); + if (!path_utf8) return 0; + + // convert mode to ASCII (we mirror _wfopen interface) + char mode_ascii[4] = {0}; + for (size_t i = 0; mode[i]; ++i) mode_ascii[i] = static_cast(mode[i]); + + // try to open the utf8 path + FILE* result = fopen(path_utf8, mode_ascii); + + // free dummy buffer + xml_memory::deallocate(path_utf8); + + return result; + } +#endif + + PUGI__FN bool save_file_impl(const xml_document& doc, FILE* file, const char_t* indent, unsigned int flags, xml_encoding encoding) + { + if (!file) return false; + + xml_writer_file writer(file); + doc.save(writer, indent, flags, encoding); + + int result = ferror(file); + + fclose(file); + + return result == 0; + } +PUGI__NS_END + +namespace pugi +{ + PUGI__FN xml_writer_file::xml_writer_file(void* file_): file(file_) + { + } + + PUGI__FN void xml_writer_file::write(const void* data, size_t size) + { + size_t 
result = fwrite(data, 1, size, static_cast(file)); + (void)!result; // unfortunately we can't do proper error handling here + } + +#ifndef PUGIXML_NO_STL + PUGI__FN xml_writer_stream::xml_writer_stream(std::basic_ostream >& stream): narrow_stream(&stream), wide_stream(0) + { + } + + PUGI__FN xml_writer_stream::xml_writer_stream(std::basic_ostream >& stream): narrow_stream(0), wide_stream(&stream) + { + } + + PUGI__FN void xml_writer_stream::write(const void* data, size_t size) + { + if (narrow_stream) + { + assert(!wide_stream); + narrow_stream->write(reinterpret_cast(data), static_cast(size)); + } + else + { + assert(wide_stream); + assert(size % sizeof(wchar_t) == 0); + + wide_stream->write(reinterpret_cast(data), static_cast(size / sizeof(wchar_t))); + } + } +#endif + + PUGI__FN xml_tree_walker::xml_tree_walker(): _depth(0) + { + } + + PUGI__FN xml_tree_walker::~xml_tree_walker() + { + } + + PUGI__FN int xml_tree_walker::depth() const + { + return _depth; + } + + PUGI__FN bool xml_tree_walker::begin(xml_node&) + { + return true; + } + + PUGI__FN bool xml_tree_walker::end(xml_node&) + { + return true; + } + + PUGI__FN xml_attribute::xml_attribute(): _attr(0) + { + } + + PUGI__FN xml_attribute::xml_attribute(xml_attribute_struct* attr): _attr(attr) + { + } + + PUGI__FN static void unspecified_bool_xml_attribute(xml_attribute***) + { + } + + PUGI__FN xml_attribute::operator xml_attribute::unspecified_bool_type() const + { + return _attr ? 
unspecified_bool_xml_attribute : 0; + } + + PUGI__FN bool xml_attribute::operator!() const + { + return !_attr; + } + + PUGI__FN bool xml_attribute::operator==(const xml_attribute& r) const + { + return (_attr == r._attr); + } + + PUGI__FN bool xml_attribute::operator!=(const xml_attribute& r) const + { + return (_attr != r._attr); + } + + PUGI__FN bool xml_attribute::operator<(const xml_attribute& r) const + { + return (_attr < r._attr); + } + + PUGI__FN bool xml_attribute::operator>(const xml_attribute& r) const + { + return (_attr > r._attr); + } + + PUGI__FN bool xml_attribute::operator<=(const xml_attribute& r) const + { + return (_attr <= r._attr); + } + + PUGI__FN bool xml_attribute::operator>=(const xml_attribute& r) const + { + return (_attr >= r._attr); + } + + PUGI__FN xml_attribute xml_attribute::next_attribute() const + { + return _attr ? xml_attribute(_attr->next_attribute) : xml_attribute(); + } + + PUGI__FN xml_attribute xml_attribute::previous_attribute() const + { + return _attr && _attr->prev_attribute_c->next_attribute ? xml_attribute(_attr->prev_attribute_c) : xml_attribute(); + } + + PUGI__FN const char_t* xml_attribute::as_string(const char_t* def) const + { + return (_attr && _attr->value) ? _attr->value : def; + } + + PUGI__FN int xml_attribute::as_int(int def) const + { + return impl::get_value_int(_attr ? _attr->value : 0, def); + } + + PUGI__FN unsigned int xml_attribute::as_uint(unsigned int def) const + { + return impl::get_value_uint(_attr ? _attr->value : 0, def); + } + + PUGI__FN double xml_attribute::as_double(double def) const + { + return impl::get_value_double(_attr ? _attr->value : 0, def); + } + + PUGI__FN float xml_attribute::as_float(float def) const + { + return impl::get_value_float(_attr ? _attr->value : 0, def); + } + + PUGI__FN bool xml_attribute::as_bool(bool def) const + { + return impl::get_value_bool(_attr ? 
_attr->value : 0, def); + } + + PUGI__FN bool xml_attribute::empty() const + { + return !_attr; + } + + PUGI__FN const char_t* xml_attribute::name() const + { + return (_attr && _attr->name) ? _attr->name : PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* xml_attribute::value() const + { + return (_attr && _attr->value) ? _attr->value : PUGIXML_TEXT(""); + } + + PUGI__FN size_t xml_attribute::hash_value() const + { + return static_cast(reinterpret_cast(_attr) / sizeof(xml_attribute_struct)); + } + + PUGI__FN xml_attribute_struct* xml_attribute::internal_object() const + { + return _attr; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(const char_t* rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(int rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(unsigned int rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(double rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(bool rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN bool xml_attribute::set_name(const char_t* rhs) + { + if (!_attr) return false; + + return impl::strcpy_insitu(_attr->name, _attr->header, impl::xml_memory_page_name_allocated_mask, rhs); + } + + PUGI__FN bool xml_attribute::set_value(const char_t* rhs) + { + if (!_attr) return false; + + return impl::strcpy_insitu(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs); + } + + PUGI__FN bool xml_attribute::set_value(int rhs) + { + if (!_attr) return false; + + return impl::set_value_convert(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs); + } + + PUGI__FN bool xml_attribute::set_value(unsigned int rhs) + { + if (!_attr) return false; + + return impl::set_value_convert(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs); + } + + PUGI__FN bool 
xml_attribute::set_value(double rhs) + { + if (!_attr) return false; + + return impl::set_value_convert(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs); + } + + PUGI__FN bool xml_attribute::set_value(bool rhs) + { + if (!_attr) return false; + + return impl::set_value_convert(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs); + } + +#ifdef __BORLANDC__ + PUGI__FN bool operator&&(const xml_attribute& lhs, bool rhs) + { + return (bool)lhs && rhs; + } + + PUGI__FN bool operator||(const xml_attribute& lhs, bool rhs) + { + return (bool)lhs || rhs; + } +#endif + + PUGI__FN xml_node::xml_node(): _root(0) + { + } + + PUGI__FN xml_node::xml_node(xml_node_struct* p): _root(p) + { + } + + PUGI__FN static void unspecified_bool_xml_node(xml_node***) + { + } + + PUGI__FN xml_node::operator xml_node::unspecified_bool_type() const + { + return _root ? unspecified_bool_xml_node : 0; + } + + PUGI__FN bool xml_node::operator!() const + { + return !_root; + } + + PUGI__FN xml_node::iterator xml_node::begin() const + { + return iterator(_root ? _root->first_child : 0, _root); + } + + PUGI__FN xml_node::iterator xml_node::end() const + { + return iterator(0, _root); + } + + PUGI__FN xml_node::attribute_iterator xml_node::attributes_begin() const + { + return attribute_iterator(_root ? 
_root->first_attribute : 0, _root); + } + + PUGI__FN xml_node::attribute_iterator xml_node::attributes_end() const + { + return attribute_iterator(0, _root); + } + + PUGI__FN xml_object_range xml_node::children() const + { + return xml_object_range(begin(), end()); + } + + PUGI__FN xml_object_range xml_node::children(const char_t* name_) const + { + return xml_object_range(xml_named_node_iterator(child(name_), name_), xml_named_node_iterator()); + } + + PUGI__FN xml_object_range xml_node::attributes() const + { + return xml_object_range(attributes_begin(), attributes_end()); + } + + PUGI__FN bool xml_node::operator==(const xml_node& r) const + { + return (_root == r._root); + } + + PUGI__FN bool xml_node::operator!=(const xml_node& r) const + { + return (_root != r._root); + } + + PUGI__FN bool xml_node::operator<(const xml_node& r) const + { + return (_root < r._root); + } + + PUGI__FN bool xml_node::operator>(const xml_node& r) const + { + return (_root > r._root); + } + + PUGI__FN bool xml_node::operator<=(const xml_node& r) const + { + return (_root <= r._root); + } + + PUGI__FN bool xml_node::operator>=(const xml_node& r) const + { + return (_root >= r._root); + } + + PUGI__FN bool xml_node::empty() const + { + return !_root; + } + + PUGI__FN const char_t* xml_node::name() const + { + return (_root && _root->name) ? _root->name : PUGIXML_TEXT(""); + } + + PUGI__FN xml_node_type xml_node::type() const + { + return _root ? static_cast((_root->header & impl::xml_memory_page_type_mask) + 1) : node_null; + } + + PUGI__FN const char_t* xml_node::value() const + { + return (_root && _root->value) ? 
_root->value : PUGIXML_TEXT(""); + } + + PUGI__FN xml_node xml_node::child(const char_t* name_) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + if (i->name && impl::strequal(name_, i->name)) return xml_node(i); + + return xml_node(); + } + + PUGI__FN xml_attribute xml_node::attribute(const char_t* name_) const + { + if (!_root) return xml_attribute(); + + for (xml_attribute_struct* i = _root->first_attribute; i; i = i->next_attribute) + if (i->name && impl::strequal(name_, i->name)) + return xml_attribute(i); + + return xml_attribute(); + } + + PUGI__FN xml_node xml_node::next_sibling(const char_t* name_) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->next_sibling; i; i = i->next_sibling) + if (i->name && impl::strequal(name_, i->name)) return xml_node(i); + + return xml_node(); + } + + PUGI__FN xml_node xml_node::next_sibling() const + { + if (!_root) return xml_node(); + + if (_root->next_sibling) return xml_node(_root->next_sibling); + else return xml_node(); + } + + PUGI__FN xml_node xml_node::previous_sibling(const char_t* name_) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->prev_sibling_c; i->next_sibling; i = i->prev_sibling_c) + if (i->name && impl::strequal(name_, i->name)) return xml_node(i); + + return xml_node(); + } + + PUGI__FN xml_node xml_node::previous_sibling() const + { + if (!_root) return xml_node(); + + if (_root->prev_sibling_c->next_sibling) return xml_node(_root->prev_sibling_c); + else return xml_node(); + } + + PUGI__FN xml_node xml_node::parent() const + { + return _root ? 
xml_node(_root->parent) : xml_node(); + } + + PUGI__FN xml_node xml_node::root() const + { + if (!_root) return xml_node(); + + impl::xml_memory_page* page = reinterpret_cast(_root->header & impl::xml_memory_page_pointer_mask); + + return xml_node(static_cast(page->allocator)); + } + + PUGI__FN xml_text xml_node::text() const + { + return xml_text(_root); + } + + PUGI__FN const char_t* xml_node::child_value() const + { + if (!_root) return PUGIXML_TEXT(""); + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + if (i->value && impl::is_text_node(i)) + return i->value; + + return PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* xml_node::child_value(const char_t* name_) const + { + return child(name_).child_value(); + } + + PUGI__FN xml_attribute xml_node::first_attribute() const + { + return _root ? xml_attribute(_root->first_attribute) : xml_attribute(); + } + + PUGI__FN xml_attribute xml_node::last_attribute() const + { + return _root && _root->first_attribute ? xml_attribute(_root->first_attribute->prev_attribute_c) : xml_attribute(); + } + + PUGI__FN xml_node xml_node::first_child() const + { + return _root ? xml_node(_root->first_child) : xml_node(); + } + + PUGI__FN xml_node xml_node::last_child() const + { + return _root && _root->first_child ? 
xml_node(_root->first_child->prev_sibling_c) : xml_node(); + } + + PUGI__FN bool xml_node::set_name(const char_t* rhs) + { + switch (type()) + { + case node_pi: + case node_declaration: + case node_element: + return impl::strcpy_insitu(_root->name, _root->header, impl::xml_memory_page_name_allocated_mask, rhs); + + default: + return false; + } + } + + PUGI__FN bool xml_node::set_value(const char_t* rhs) + { + switch (type()) + { + case node_pi: + case node_cdata: + case node_pcdata: + case node_comment: + case node_doctype: + return impl::strcpy_insitu(_root->value, _root->header, impl::xml_memory_page_value_allocated_mask, rhs); + + default: + return false; + } + } + + PUGI__FN xml_attribute xml_node::append_attribute(const char_t* name_) + { + if (type() != node_element && type() != node_declaration) return xml_attribute(); + + xml_attribute a(impl::append_attribute_ll(_root, impl::get_allocator(_root))); + a.set_name(name_); + + return a; + } + + PUGI__FN xml_attribute xml_node::prepend_attribute(const char_t* name_) + { + if (type() != node_element && type() != node_declaration) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(impl::get_allocator(_root))); + if (!a) return xml_attribute(); + + a.set_name(name_); + + xml_attribute_struct* head = _root->first_attribute; + + if (head) + { + a._attr->prev_attribute_c = head->prev_attribute_c; + head->prev_attribute_c = a._attr; + } + else + a._attr->prev_attribute_c = a._attr; + + a._attr->next_attribute = head; + _root->first_attribute = a._attr; + + return a; + } + + PUGI__FN xml_attribute xml_node::insert_attribute_before(const char_t* name_, const xml_attribute& attr) + { + if ((type() != node_element && type() != node_declaration) || attr.empty()) return xml_attribute(); + + // check that attribute belongs to *this + xml_attribute_struct* cur = attr._attr; + + while (cur->prev_attribute_c->next_attribute) cur = cur->prev_attribute_c; + + if (cur != _root->first_attribute) return 
xml_attribute(); + + xml_attribute a(impl::allocate_attribute(impl::get_allocator(_root))); + if (!a) return xml_attribute(); + + a.set_name(name_); + + if (attr._attr->prev_attribute_c->next_attribute) + attr._attr->prev_attribute_c->next_attribute = a._attr; + else + _root->first_attribute = a._attr; + + a._attr->prev_attribute_c = attr._attr->prev_attribute_c; + a._attr->next_attribute = attr._attr; + attr._attr->prev_attribute_c = a._attr; + + return a; + } + + PUGI__FN xml_attribute xml_node::insert_attribute_after(const char_t* name_, const xml_attribute& attr) + { + if ((type() != node_element && type() != node_declaration) || attr.empty()) return xml_attribute(); + + // check that attribute belongs to *this + xml_attribute_struct* cur = attr._attr; + + while (cur->prev_attribute_c->next_attribute) cur = cur->prev_attribute_c; + + if (cur != _root->first_attribute) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(impl::get_allocator(_root))); + if (!a) return xml_attribute(); + + a.set_name(name_); + + if (attr._attr->next_attribute) + attr._attr->next_attribute->prev_attribute_c = a._attr; + else + _root->first_attribute->prev_attribute_c = a._attr; + + a._attr->next_attribute = attr._attr->next_attribute; + a._attr->prev_attribute_c = attr._attr; + attr._attr->next_attribute = a._attr; + + return a; + } + + PUGI__FN xml_attribute xml_node::append_copy(const xml_attribute& proto) + { + if (!proto) return xml_attribute(); + + xml_attribute result = append_attribute(proto.name()); + result.set_value(proto.value()); + + return result; + } + + PUGI__FN xml_attribute xml_node::prepend_copy(const xml_attribute& proto) + { + if (!proto) return xml_attribute(); + + xml_attribute result = prepend_attribute(proto.name()); + result.set_value(proto.value()); + + return result; + } + + PUGI__FN xml_attribute xml_node::insert_copy_after(const xml_attribute& proto, const xml_attribute& attr) + { + if (!proto) return xml_attribute(); + + xml_attribute 
result = insert_attribute_after(proto.name(), attr); + result.set_value(proto.value()); + + return result; + } + + PUGI__FN xml_attribute xml_node::insert_copy_before(const xml_attribute& proto, const xml_attribute& attr) + { + if (!proto) return xml_attribute(); + + xml_attribute result = insert_attribute_before(proto.name(), attr); + result.set_value(proto.value()); + + return result; + } + + PUGI__FN xml_node xml_node::append_child(xml_node_type type_) + { + if (!impl::allow_insert_child(this->type(), type_)) return xml_node(); + + xml_node n(impl::append_node(_root, impl::get_allocator(_root), type_)); + + if (type_ == node_declaration) n.set_name(PUGIXML_TEXT("xml")); + + return n; + } + + PUGI__FN xml_node xml_node::prepend_child(xml_node_type type_) + { + if (!impl::allow_insert_child(this->type(), type_)) return xml_node(); + + xml_node n(impl::allocate_node(impl::get_allocator(_root), type_)); + if (!n) return xml_node(); + + n._root->parent = _root; + + xml_node_struct* head = _root->first_child; + + if (head) + { + n._root->prev_sibling_c = head->prev_sibling_c; + head->prev_sibling_c = n._root; + } + else + n._root->prev_sibling_c = n._root; + + n._root->next_sibling = head; + _root->first_child = n._root; + + if (type_ == node_declaration) n.set_name(PUGIXML_TEXT("xml")); + + return n; + } + + PUGI__FN xml_node xml_node::insert_child_before(xml_node_type type_, const xml_node& node) + { + if (!impl::allow_insert_child(this->type(), type_)) return xml_node(); + if (!node._root || node._root->parent != _root) return xml_node(); + + xml_node n(impl::allocate_node(impl::get_allocator(_root), type_)); + if (!n) return xml_node(); + + n._root->parent = _root; + + if (node._root->prev_sibling_c->next_sibling) + node._root->prev_sibling_c->next_sibling = n._root; + else + _root->first_child = n._root; + + n._root->prev_sibling_c = node._root->prev_sibling_c; + n._root->next_sibling = node._root; + node._root->prev_sibling_c = n._root; + + if (type_ == 
node_declaration) n.set_name(PUGIXML_TEXT("xml")); + + return n; + } + + PUGI__FN xml_node xml_node::insert_child_after(xml_node_type type_, const xml_node& node) + { + if (!impl::allow_insert_child(this->type(), type_)) return xml_node(); + if (!node._root || node._root->parent != _root) return xml_node(); + + xml_node n(impl::allocate_node(impl::get_allocator(_root), type_)); + if (!n) return xml_node(); + + n._root->parent = _root; + + if (node._root->next_sibling) + node._root->next_sibling->prev_sibling_c = n._root; + else + _root->first_child->prev_sibling_c = n._root; + + n._root->next_sibling = node._root->next_sibling; + n._root->prev_sibling_c = node._root; + node._root->next_sibling = n._root; + + if (type_ == node_declaration) n.set_name(PUGIXML_TEXT("xml")); + + return n; + } + + PUGI__FN xml_node xml_node::append_child(const char_t* name_) + { + xml_node result = append_child(node_element); + + result.set_name(name_); + + return result; + } + + PUGI__FN xml_node xml_node::prepend_child(const char_t* name_) + { + xml_node result = prepend_child(node_element); + + result.set_name(name_); + + return result; + } + + PUGI__FN xml_node xml_node::insert_child_after(const char_t* name_, const xml_node& node) + { + xml_node result = insert_child_after(node_element, node); + + result.set_name(name_); + + return result; + } + + PUGI__FN xml_node xml_node::insert_child_before(const char_t* name_, const xml_node& node) + { + xml_node result = insert_child_before(node_element, node); + + result.set_name(name_); + + return result; + } + + PUGI__FN xml_node xml_node::append_copy(const xml_node& proto) + { + xml_node result = append_child(proto.type()); + + if (result) impl::recursive_copy_skip(result, proto, result); + + return result; + } + + PUGI__FN xml_node xml_node::prepend_copy(const xml_node& proto) + { + xml_node result = prepend_child(proto.type()); + + if (result) impl::recursive_copy_skip(result, proto, result); + + return result; + } + + PUGI__FN xml_node 
xml_node::insert_copy_after(const xml_node& proto, const xml_node& node) + { + xml_node result = insert_child_after(proto.type(), node); + + if (result) impl::recursive_copy_skip(result, proto, result); + + return result; + } + + PUGI__FN xml_node xml_node::insert_copy_before(const xml_node& proto, const xml_node& node) + { + xml_node result = insert_child_before(proto.type(), node); + + if (result) impl::recursive_copy_skip(result, proto, result); + + return result; + } + + PUGI__FN bool xml_node::remove_attribute(const char_t* name_) + { + return remove_attribute(attribute(name_)); + } + + PUGI__FN bool xml_node::remove_attribute(const xml_attribute& a) + { + if (!_root || !a._attr) return false; + + // check that attribute belongs to *this + xml_attribute_struct* attr = a._attr; + + while (attr->prev_attribute_c->next_attribute) attr = attr->prev_attribute_c; + + if (attr != _root->first_attribute) return false; + + if (a._attr->next_attribute) a._attr->next_attribute->prev_attribute_c = a._attr->prev_attribute_c; + else if (_root->first_attribute) _root->first_attribute->prev_attribute_c = a._attr->prev_attribute_c; + + if (a._attr->prev_attribute_c->next_attribute) a._attr->prev_attribute_c->next_attribute = a._attr->next_attribute; + else _root->first_attribute = a._attr->next_attribute; + + impl::destroy_attribute(a._attr, impl::get_allocator(_root)); + + return true; + } + + PUGI__FN bool xml_node::remove_child(const char_t* name_) + { + return remove_child(child(name_)); + } + + PUGI__FN bool xml_node::remove_child(const xml_node& n) + { + if (!_root || !n._root || n._root->parent != _root) return false; + + if (n._root->next_sibling) n._root->next_sibling->prev_sibling_c = n._root->prev_sibling_c; + else if (_root->first_child) _root->first_child->prev_sibling_c = n._root->prev_sibling_c; + + if (n._root->prev_sibling_c->next_sibling) n._root->prev_sibling_c->next_sibling = n._root->next_sibling; + else _root->first_child = n._root->next_sibling; + + 
impl::destroy_node(n._root, impl::get_allocator(_root)); + + return true; + } + + PUGI__FN xml_node xml_node::find_child_by_attribute(const char_t* name_, const char_t* attr_name, const char_t* attr_value) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + if (i->name && impl::strequal(name_, i->name)) + { + for (xml_attribute_struct* a = i->first_attribute; a; a = a->next_attribute) + if (impl::strequal(attr_name, a->name) && impl::strequal(attr_value, a->value)) + return xml_node(i); + } + + return xml_node(); + } + + PUGI__FN xml_node xml_node::find_child_by_attribute(const char_t* attr_name, const char_t* attr_value) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + for (xml_attribute_struct* a = i->first_attribute; a; a = a->next_attribute) + if (impl::strequal(attr_name, a->name) && impl::strequal(attr_value, a->value)) + return xml_node(i); + + return xml_node(); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN string_t xml_node::path(char_t delimiter) const + { + xml_node cursor = *this; // Make a copy. + + string_t result = cursor.name(); + + while (cursor.parent()) + { + cursor = cursor.parent(); + + string_t temp = cursor.name(); + temp += delimiter; + temp += result; + result.swap(temp); + } + + return result; + } +#endif + + PUGI__FN xml_node xml_node::first_element_by_path(const char_t* path_, char_t delimiter) const + { + xml_node found = *this; // Current search context. + + if (!_root || !path_ || !path_[0]) return found; + + if (path_[0] == delimiter) + { + // Absolute path; e.g. 
'/foo/bar' + found = found.root(); + ++path_; + } + + const char_t* path_segment = path_; + + while (*path_segment == delimiter) ++path_segment; + + const char_t* path_segment_end = path_segment; + + while (*path_segment_end && *path_segment_end != delimiter) ++path_segment_end; + + if (path_segment == path_segment_end) return found; + + const char_t* next_segment = path_segment_end; + + while (*next_segment == delimiter) ++next_segment; + + if (*path_segment == '.' && path_segment + 1 == path_segment_end) + return found.first_element_by_path(next_segment, delimiter); + else if (*path_segment == '.' && *(path_segment+1) == '.' && path_segment + 2 == path_segment_end) + return found.parent().first_element_by_path(next_segment, delimiter); + else + { + for (xml_node_struct* j = found._root->first_child; j; j = j->next_sibling) + { + if (j->name && impl::strequalrange(j->name, path_segment, static_cast(path_segment_end - path_segment))) + { + xml_node subsearch = xml_node(j).first_element_by_path(next_segment, delimiter); + + if (subsearch) return subsearch; + } + } + + return xml_node(); + } + } + + PUGI__FN bool xml_node::traverse(xml_tree_walker& walker) + { + walker._depth = -1; + + xml_node arg_begin = *this; + if (!walker.begin(arg_begin)) return false; + + xml_node cur = first_child(); + + if (cur) + { + ++walker._depth; + + do + { + xml_node arg_for_each = cur; + if (!walker.for_each(arg_for_each)) + return false; + + if (cur.first_child()) + { + ++walker._depth; + cur = cur.first_child(); + } + else if (cur.next_sibling()) + cur = cur.next_sibling(); + else + { + // Borland C++ workaround + while (!cur.next_sibling() && cur != *this && !cur.parent().empty()) + { + --walker._depth; + cur = cur.parent(); + } + + if (cur != *this) + cur = cur.next_sibling(); + } + } + while (cur && cur != *this); + } + + assert(walker._depth == -1); + + xml_node arg_end = *this; + return walker.end(arg_end); + } + + PUGI__FN size_t xml_node::hash_value() const + { + return 
static_cast(reinterpret_cast(_root) / sizeof(xml_node_struct)); + } + + PUGI__FN xml_node_struct* xml_node::internal_object() const + { + return _root; + } + + PUGI__FN void xml_node::print(xml_writer& writer, const char_t* indent, unsigned int flags, xml_encoding encoding, unsigned int depth) const + { + if (!_root) return; + + impl::xml_buffered_writer buffered_writer(writer, encoding); + + impl::node_output(buffered_writer, *this, indent, flags, depth); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN void xml_node::print(std::basic_ostream >& stream, const char_t* indent, unsigned int flags, xml_encoding encoding, unsigned int depth) const + { + xml_writer_stream writer(stream); + + print(writer, indent, flags, encoding, depth); + } + + PUGI__FN void xml_node::print(std::basic_ostream >& stream, const char_t* indent, unsigned int flags, unsigned int depth) const + { + xml_writer_stream writer(stream); + + print(writer, indent, flags, encoding_wchar, depth); + } +#endif + + PUGI__FN ptrdiff_t xml_node::offset_debug() const + { + xml_node_struct* r = root()._root; + + if (!r) return -1; + + const char_t* buffer = static_cast(r)->buffer; + + if (!buffer) return -1; + + switch (type()) + { + case node_document: + return 0; + + case node_element: + case node_declaration: + case node_pi: + return (_root->header & impl::xml_memory_page_name_allocated_mask) ? -1 : _root->name - buffer; + + case node_pcdata: + case node_cdata: + case node_comment: + case node_doctype: + return (_root->header & impl::xml_memory_page_value_allocated_mask) ? 
-1 : _root->value - buffer; + + default: + return -1; + } + } + +#ifdef __BORLANDC__ + PUGI__FN bool operator&&(const xml_node& lhs, bool rhs) + { + return (bool)lhs && rhs; + } + + PUGI__FN bool operator||(const xml_node& lhs, bool rhs) + { + return (bool)lhs || rhs; + } +#endif + + PUGI__FN xml_text::xml_text(xml_node_struct* root): _root(root) + { + } + + PUGI__FN xml_node_struct* xml_text::_data() const + { + if (!_root || impl::is_text_node(_root)) return _root; + + for (xml_node_struct* node = _root->first_child; node; node = node->next_sibling) + if (impl::is_text_node(node)) + return node; + + return 0; + } + + PUGI__FN xml_node_struct* xml_text::_data_new() + { + xml_node_struct* d = _data(); + if (d) return d; + + return xml_node(_root).append_child(node_pcdata).internal_object(); + } + + PUGI__FN xml_text::xml_text(): _root(0) + { + } + + PUGI__FN static void unspecified_bool_xml_text(xml_text***) + { + } + + PUGI__FN xml_text::operator xml_text::unspecified_bool_type() const + { + return _data() ? unspecified_bool_xml_text : 0; + } + + PUGI__FN bool xml_text::operator!() const + { + return !_data(); + } + + PUGI__FN bool xml_text::empty() const + { + return _data() == 0; + } + + PUGI__FN const char_t* xml_text::get() const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? d->value : PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* xml_text::as_string(const char_t* def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? d->value : def; + } + + PUGI__FN int xml_text::as_int(int def) const + { + xml_node_struct* d = _data(); + + return impl::get_value_int(d ? d->value : 0, def); + } + + PUGI__FN unsigned int xml_text::as_uint(unsigned int def) const + { + xml_node_struct* d = _data(); + + return impl::get_value_uint(d ? d->value : 0, def); + } + + PUGI__FN double xml_text::as_double(double def) const + { + xml_node_struct* d = _data(); + + return impl::get_value_double(d ? 
d->value : 0, def); + } + + PUGI__FN float xml_text::as_float(float def) const + { + xml_node_struct* d = _data(); + + return impl::get_value_float(d ? d->value : 0, def); + } + + PUGI__FN bool xml_text::as_bool(bool def) const + { + xml_node_struct* d = _data(); + + return impl::get_value_bool(d ? d->value : 0, def); + } + + PUGI__FN bool xml_text::set(const char_t* rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::strcpy_insitu(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs) : false; + } + + PUGI__FN bool xml_text::set(int rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_convert(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs) : false; + } + + PUGI__FN bool xml_text::set(unsigned int rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_convert(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs) : false; + } + + PUGI__FN bool xml_text::set(double rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_convert(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs) : false; + } + + PUGI__FN bool xml_text::set(bool rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? 
impl::set_value_convert(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs) : false; + } + + PUGI__FN xml_text& xml_text::operator=(const char_t* rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(int rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(unsigned int rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(double rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(bool rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_node xml_text::data() const + { + return xml_node(_data()); + } + +#ifdef __BORLANDC__ + PUGI__FN bool operator&&(const xml_text& lhs, bool rhs) + { + return (bool)lhs && rhs; + } + + PUGI__FN bool operator||(const xml_text& lhs, bool rhs) + { + return (bool)lhs || rhs; + } +#endif + + PUGI__FN xml_node_iterator::xml_node_iterator() + { + } + + PUGI__FN xml_node_iterator::xml_node_iterator(const xml_node& node): _wrap(node), _parent(node.parent()) + { + } + + PUGI__FN xml_node_iterator::xml_node_iterator(xml_node_struct* ref, xml_node_struct* parent): _wrap(ref), _parent(parent) + { + } + + PUGI__FN bool xml_node_iterator::operator==(const xml_node_iterator& rhs) const + { + return _wrap._root == rhs._wrap._root && _parent._root == rhs._parent._root; + } + + PUGI__FN bool xml_node_iterator::operator!=(const xml_node_iterator& rhs) const + { + return _wrap._root != rhs._wrap._root || _parent._root != rhs._parent._root; + } + + PUGI__FN xml_node& xml_node_iterator::operator*() const + { + assert(_wrap._root); + return _wrap; + } + + PUGI__FN xml_node* xml_node_iterator::operator->() const + { + assert(_wrap._root); + return const_cast(&_wrap); // BCC32 workaround + } + + PUGI__FN const xml_node_iterator& xml_node_iterator::operator++() + { + assert(_wrap._root); + _wrap._root = _wrap._root->next_sibling; + return *this; + } + + PUGI__FN xml_node_iterator xml_node_iterator::operator++(int) 
+ { + xml_node_iterator temp = *this; + ++*this; + return temp; + } + + PUGI__FN const xml_node_iterator& xml_node_iterator::operator--() + { + _wrap = _wrap._root ? _wrap.previous_sibling() : _parent.last_child(); + return *this; + } + + PUGI__FN xml_node_iterator xml_node_iterator::operator--(int) + { + xml_node_iterator temp = *this; + --*this; + return temp; + } + + PUGI__FN xml_attribute_iterator::xml_attribute_iterator() + { + } + + PUGI__FN xml_attribute_iterator::xml_attribute_iterator(const xml_attribute& attr, const xml_node& parent): _wrap(attr), _parent(parent) + { + } + + PUGI__FN xml_attribute_iterator::xml_attribute_iterator(xml_attribute_struct* ref, xml_node_struct* parent): _wrap(ref), _parent(parent) + { + } + + PUGI__FN bool xml_attribute_iterator::operator==(const xml_attribute_iterator& rhs) const + { + return _wrap._attr == rhs._wrap._attr && _parent._root == rhs._parent._root; + } + + PUGI__FN bool xml_attribute_iterator::operator!=(const xml_attribute_iterator& rhs) const + { + return _wrap._attr != rhs._wrap._attr || _parent._root != rhs._parent._root; + } + + PUGI__FN xml_attribute& xml_attribute_iterator::operator*() const + { + assert(_wrap._attr); + return _wrap; + } + + PUGI__FN xml_attribute* xml_attribute_iterator::operator->() const + { + assert(_wrap._attr); + return const_cast(&_wrap); // BCC32 workaround + } + + PUGI__FN const xml_attribute_iterator& xml_attribute_iterator::operator++() + { + assert(_wrap._attr); + _wrap._attr = _wrap._attr->next_attribute; + return *this; + } + + PUGI__FN xml_attribute_iterator xml_attribute_iterator::operator++(int) + { + xml_attribute_iterator temp = *this; + ++*this; + return temp; + } + + PUGI__FN const xml_attribute_iterator& xml_attribute_iterator::operator--() + { + _wrap = _wrap._attr ? 
_wrap.previous_attribute() : _parent.last_attribute(); + return *this; + } + + PUGI__FN xml_attribute_iterator xml_attribute_iterator::operator--(int) + { + xml_attribute_iterator temp = *this; + --*this; + return temp; + } + + PUGI__FN xml_named_node_iterator::xml_named_node_iterator(): _name(0) + { + } + + PUGI__FN xml_named_node_iterator::xml_named_node_iterator(const xml_node& node, const char_t* name): _node(node), _name(name) + { + } + + PUGI__FN bool xml_named_node_iterator::operator==(const xml_named_node_iterator& rhs) const + { + return _node == rhs._node; + } + + PUGI__FN bool xml_named_node_iterator::operator!=(const xml_named_node_iterator& rhs) const + { + return _node != rhs._node; + } + + PUGI__FN xml_node& xml_named_node_iterator::operator*() const + { + assert(_node._root); + return _node; + } + + PUGI__FN xml_node* xml_named_node_iterator::operator->() const + { + assert(_node._root); + return const_cast(&_node); // BCC32 workaround + } + + PUGI__FN const xml_named_node_iterator& xml_named_node_iterator::operator++() + { + assert(_node._root); + _node = _node.next_sibling(_name); + return *this; + } + + PUGI__FN xml_named_node_iterator xml_named_node_iterator::operator++(int) + { + xml_named_node_iterator temp = *this; + ++*this; + return temp; + } + + PUGI__FN xml_parse_result::xml_parse_result(): status(status_internal_error), offset(0), encoding(encoding_auto) + { + } + + PUGI__FN xml_parse_result::operator bool() const + { + return status == status_ok; + } + + PUGI__FN const char* xml_parse_result::description() const + { + switch (status) + { + case status_ok: return "No error"; + + case status_file_not_found: return "File was not found"; + case status_io_error: return "Error reading from file/stream"; + case status_out_of_memory: return "Could not allocate memory"; + case status_internal_error: return "Internal error occurred"; + + case status_unrecognized_tag: return "Could not determine tag type"; + + case status_bad_pi: return "Error 
parsing document declaration/processing instruction"; + case status_bad_comment: return "Error parsing comment"; + case status_bad_cdata: return "Error parsing CDATA section"; + case status_bad_doctype: return "Error parsing document type declaration"; + case status_bad_pcdata: return "Error parsing PCDATA section"; + case status_bad_start_element: return "Error parsing start element tag"; + case status_bad_attribute: return "Error parsing element attribute"; + case status_bad_end_element: return "Error parsing end element tag"; + case status_end_element_mismatch: return "Start-end tags mismatch"; + + default: return "Unknown error"; + } + } + + PUGI__FN xml_document::xml_document(): _buffer(0) + { + create(); + } + + PUGI__FN xml_document::~xml_document() + { + destroy(); + } + + PUGI__FN void xml_document::reset() + { + destroy(); + create(); + } + + PUGI__FN void xml_document::reset(const xml_document& proto) + { + reset(); + + for (xml_node cur = proto.first_child(); cur; cur = cur.next_sibling()) + append_copy(cur); + } + + PUGI__FN void xml_document::create() + { + // initialize sentinel page + PUGI__STATIC_ASSERT(offsetof(impl::xml_memory_page, data) + sizeof(impl::xml_document_struct) + impl::xml_memory_page_alignment <= sizeof(_memory)); + + // align upwards to page boundary + void* page_memory = reinterpret_cast((reinterpret_cast(_memory) + (impl::xml_memory_page_alignment - 1)) & ~(impl::xml_memory_page_alignment - 1)); + + // prepare page structure + impl::xml_memory_page* page = impl::xml_memory_page::construct(page_memory); + + page->busy_size = impl::xml_memory_page_size; + + // allocate new root + _root = new (page->data) impl::xml_document_struct(page); + _root->prev_sibling_c = _root; + + // setup sentinel page + page->allocator = static_cast(_root); + } + + PUGI__FN void xml_document::destroy() + { + // destroy static storage + if (_buffer) + { + impl::xml_memory::deallocate(_buffer); + _buffer = 0; + } + + // destroy dynamic storage, leave 
sentinel page (it's in static memory) + if (_root) + { + impl::xml_memory_page* root_page = reinterpret_cast(_root->header & impl::xml_memory_page_pointer_mask); + assert(root_page && !root_page->prev && !root_page->memory); + + // destroy all pages + for (impl::xml_memory_page* page = root_page->next; page; ) + { + impl::xml_memory_page* next = page->next; + + impl::xml_allocator::deallocate_page(page); + + page = next; + } + + // cleanup root page + root_page->allocator = 0; + root_page->next = 0; + root_page->busy_size = root_page->freed_size = 0; + + _root = 0; + } + } + +#ifndef PUGIXML_NO_STL + PUGI__FN xml_parse_result xml_document::load(std::basic_istream >& stream, unsigned int options, xml_encoding encoding) + { + reset(); + + return impl::load_stream_impl(*this, stream, options, encoding); + } + + PUGI__FN xml_parse_result xml_document::load(std::basic_istream >& stream, unsigned int options) + { + reset(); + + return impl::load_stream_impl(*this, stream, options, encoding_wchar); + } +#endif + + PUGI__FN xml_parse_result xml_document::load(const char_t* contents, unsigned int options) + { + // Force native encoding (skip autodetection) + #ifdef PUGIXML_WCHAR_MODE + xml_encoding encoding = encoding_wchar; + #else + xml_encoding encoding = encoding_utf8; + #endif + + return load_buffer(contents, impl::strlength(contents) * sizeof(char_t), options, encoding); + } + + PUGI__FN xml_parse_result xml_document::load_file(const char* path_, unsigned int options, xml_encoding encoding) + { + reset(); + + FILE* file = fopen(path_, "rb"); + + return impl::load_file_impl(*this, file, options, encoding); + } + + PUGI__FN xml_parse_result xml_document::load_file(const wchar_t* path_, unsigned int options, xml_encoding encoding) + { + reset(); + + FILE* file = impl::open_file_wide(path_, L"rb"); + + return impl::load_file_impl(*this, file, options, encoding); + } + + PUGI__FN xml_parse_result xml_document::load_buffer_impl(void* contents, size_t size, unsigned int 
options, xml_encoding encoding, bool is_mutable, bool own) + { + reset(); + + // check input buffer + assert(contents || size == 0); + + // get actual encoding + xml_encoding buffer_encoding = impl::get_buffer_encoding(encoding, contents, size); + + // get private buffer + char_t* buffer = 0; + size_t length = 0; + + if (!impl::convert_buffer(buffer, length, buffer_encoding, contents, size, is_mutable)) return impl::make_parse_result(status_out_of_memory); + + // delete original buffer if we performed a conversion + if (own && buffer != contents && contents) impl::xml_memory::deallocate(contents); + + // parse + xml_parse_result res = impl::xml_parser::parse(buffer, length, _root, options); + + // remember encoding + res.encoding = buffer_encoding; + + // grab onto buffer if it's our buffer, user is responsible for deallocating contens himself + if (own || buffer != contents) _buffer = buffer; + + return res; + } + + PUGI__FN xml_parse_result xml_document::load_buffer(const void* contents, size_t size, unsigned int options, xml_encoding encoding) + { + return load_buffer_impl(const_cast(contents), size, options, encoding, false, false); + } + + PUGI__FN xml_parse_result xml_document::load_buffer_inplace(void* contents, size_t size, unsigned int options, xml_encoding encoding) + { + return load_buffer_impl(contents, size, options, encoding, true, false); + } + + PUGI__FN xml_parse_result xml_document::load_buffer_inplace_own(void* contents, size_t size, unsigned int options, xml_encoding encoding) + { + return load_buffer_impl(contents, size, options, encoding, true, true); + } + + PUGI__FN void xml_document::save(xml_writer& writer, const char_t* indent, unsigned int flags, xml_encoding encoding) const + { + impl::xml_buffered_writer buffered_writer(writer, encoding); + + if ((flags & format_write_bom) && encoding != encoding_latin1) + { + // BOM always represents the codepoint U+FEFF, so just write it in native encoding + #ifdef PUGIXML_WCHAR_MODE + unsigned int 
bom = 0xfeff; + buffered_writer.write(static_cast(bom)); + #else + buffered_writer.write('\xef', '\xbb', '\xbf'); + #endif + } + + if (!(flags & format_no_declaration) && !impl::has_declaration(*this)) + { + buffered_writer.write(PUGIXML_TEXT("'); + if (!(flags & format_raw)) buffered_writer.write('\n'); + } + + impl::node_output(buffered_writer, *this, indent, flags, 0); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN void xml_document::save(std::basic_ostream >& stream, const char_t* indent, unsigned int flags, xml_encoding encoding) const + { + xml_writer_stream writer(stream); + + save(writer, indent, flags, encoding); + } + + PUGI__FN void xml_document::save(std::basic_ostream >& stream, const char_t* indent, unsigned int flags) const + { + xml_writer_stream writer(stream); + + save(writer, indent, flags, encoding_wchar); + } +#endif + + PUGI__FN bool xml_document::save_file(const char* path_, const char_t* indent, unsigned int flags, xml_encoding encoding) const + { + FILE* file = fopen(path_, (flags & format_save_file_text) ? "w" : "wb"); + return impl::save_file_impl(*this, file, indent, flags, encoding); + } + + PUGI__FN bool xml_document::save_file(const wchar_t* path_, const char_t* indent, unsigned int flags, xml_encoding encoding) const + { + FILE* file = impl::open_file_wide(path_, (flags & format_save_file_text) ? 
L"w" : L"wb"); + return impl::save_file_impl(*this, file, indent, flags, encoding); + } + + PUGI__FN xml_node xml_document::document_element() const + { + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + if ((i->header & impl::xml_memory_page_type_mask) + 1 == node_element) + return xml_node(i); + + return xml_node(); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN std::string PUGIXML_FUNCTION as_utf8(const wchar_t* str) + { + assert(str); + + return impl::as_utf8_impl(str, wcslen(str)); + } + + PUGI__FN std::string PUGIXML_FUNCTION as_utf8(const std::basic_string& str) + { + return impl::as_utf8_impl(str.c_str(), str.size()); + } + + PUGI__FN std::basic_string PUGIXML_FUNCTION as_wide(const char* str) + { + assert(str); + + return impl::as_wide_impl(str, strlen(str)); + } + + PUGI__FN std::basic_string PUGIXML_FUNCTION as_wide(const std::string& str) + { + return impl::as_wide_impl(str.c_str(), str.size()); + } +#endif + + PUGI__FN void PUGIXML_FUNCTION set_memory_management_functions(allocation_function allocate, deallocation_function deallocate) + { + impl::xml_memory::allocate = allocate; + impl::xml_memory::deallocate = deallocate; + } + + PUGI__FN allocation_function PUGIXML_FUNCTION get_memory_allocation_function() + { + return impl::xml_memory::allocate; + } + + PUGI__FN deallocation_function PUGIXML_FUNCTION get_memory_deallocation_function() + { + return impl::xml_memory::deallocate; + } +} + +#if !defined(PUGIXML_NO_STL) && (defined(_MSC_VER) || defined(__ICC)) +namespace std +{ + // Workarounds for (non-standard) iterator category detection for older versions (MSVC7/IC8 and earlier) + PUGI__FN std::bidirectional_iterator_tag _Iter_cat(const pugi::xml_node_iterator&) + { + return std::bidirectional_iterator_tag(); + } + + PUGI__FN std::bidirectional_iterator_tag _Iter_cat(const pugi::xml_attribute_iterator&) + { + return std::bidirectional_iterator_tag(); + } + + PUGI__FN std::forward_iterator_tag _Iter_cat(const 
pugi::xml_named_node_iterator&) + { + return std::forward_iterator_tag(); + } +} +#endif + +#if !defined(PUGIXML_NO_STL) && defined(__SUNPRO_CC) +namespace std +{ + // Workarounds for (non-standard) iterator category detection + PUGI__FN std::bidirectional_iterator_tag __iterator_category(const pugi::xml_node_iterator&) + { + return std::bidirectional_iterator_tag(); + } + + PUGI__FN std::bidirectional_iterator_tag __iterator_category(const pugi::xml_attribute_iterator&) + { + return std::bidirectional_iterator_tag(); + } + + PUGI__FN std::forward_iterator_tag __iterator_category(const pugi::xml_named_node_iterator&) + { + return std::forward_iterator_tag(); + } +} +#endif + +#ifndef PUGIXML_NO_XPATH + +// STL replacements +PUGI__NS_BEGIN + struct equal_to + { + template bool operator()(const T& lhs, const T& rhs) const + { + return lhs == rhs; + } + }; + + struct not_equal_to + { + template bool operator()(const T& lhs, const T& rhs) const + { + return lhs != rhs; + } + }; + + struct less + { + template bool operator()(const T& lhs, const T& rhs) const + { + return lhs < rhs; + } + }; + + struct less_equal + { + template bool operator()(const T& lhs, const T& rhs) const + { + return lhs <= rhs; + } + }; + + template void swap(T& lhs, T& rhs) + { + T temp = lhs; + lhs = rhs; + rhs = temp; + } + + template I min_element(I begin, I end, const Pred& pred) + { + I result = begin; + + for (I it = begin + 1; it != end; ++it) + if (pred(*it, *result)) + result = it; + + return result; + } + + template void reverse(I begin, I end) + { + while (begin + 1 < end) swap(*begin++, *--end); + } + + template I unique(I begin, I end) + { + // fast skip head + while (begin + 1 < end && *begin != *(begin + 1)) begin++; + + if (begin == end) return begin; + + // last written element + I write = begin++; + + // merge unique elements + while (begin != end) + { + if (*begin != *write) + *++write = *begin++; + else + begin++; + } + + // past-the-end (write points to live element) + return 
write + 1; + } + + template void copy_backwards(I begin, I end, I target) + { + while (begin != end) *--target = *--end; + } + + template void insertion_sort(I begin, I end, const Pred& pred, T*) + { + assert(begin != end); + + for (I it = begin + 1; it != end; ++it) + { + T val = *it; + + if (pred(val, *begin)) + { + // move to front + copy_backwards(begin, it, it + 1); + *begin = val; + } + else + { + I hole = it; + + // move hole backwards + while (pred(val, *(hole - 1))) + { + *hole = *(hole - 1); + hole--; + } + + // fill hole with element + *hole = val; + } + } + } + + // std variant for elements with == + template void partition(I begin, I middle, I end, const Pred& pred, I* out_eqbeg, I* out_eqend) + { + I eqbeg = middle, eqend = middle + 1; + + // expand equal range + while (eqbeg != begin && *(eqbeg - 1) == *eqbeg) --eqbeg; + while (eqend != end && *eqend == *eqbeg) ++eqend; + + // process outer elements + I ltend = eqbeg, gtbeg = eqend; + + for (;;) + { + // find the element from the right side that belongs to the left one + for (; gtbeg != end; ++gtbeg) + if (!pred(*eqbeg, *gtbeg)) + { + if (*gtbeg == *eqbeg) swap(*gtbeg, *eqend++); + else break; + } + + // find the element from the left side that belongs to the right one + for (; ltend != begin; --ltend) + if (!pred(*(ltend - 1), *eqbeg)) + { + if (*eqbeg == *(ltend - 1)) swap(*(ltend - 1), *--eqbeg); + else break; + } + + // scanned all elements + if (gtbeg == end && ltend == begin) + { + *out_eqbeg = eqbeg; + *out_eqend = eqend; + return; + } + + // make room for elements by moving equal area + if (gtbeg == end) + { + if (--ltend != --eqbeg) swap(*ltend, *eqbeg); + swap(*eqbeg, *--eqend); + } + else if (ltend == begin) + { + if (eqend != gtbeg) swap(*eqbeg, *eqend); + ++eqend; + swap(*gtbeg++, *eqbeg++); + } + else swap(*gtbeg++, *--ltend); + } + } + + template void median3(I first, I middle, I last, const Pred& pred) + { + if (pred(*middle, *first)) swap(*middle, *first); + if (pred(*last, *middle)) 
swap(*last, *middle); + if (pred(*middle, *first)) swap(*middle, *first); + } + + template void median(I first, I middle, I last, const Pred& pred) + { + if (last - first <= 40) + { + // median of three for small chunks + median3(first, middle, last, pred); + } + else + { + // median of nine + size_t step = (last - first + 1) / 8; + + median3(first, first + step, first + 2 * step, pred); + median3(middle - step, middle, middle + step, pred); + median3(last - 2 * step, last - step, last, pred); + median3(first + step, middle, last - step, pred); + } + } + + template void sort(I begin, I end, const Pred& pred) + { + // sort large chunks + while (end - begin > 32) + { + // find median element + I middle = begin + (end - begin) / 2; + median(begin, middle, end - 1, pred); + + // partition in three chunks (< = >) + I eqbeg, eqend; + partition(begin, middle, end, pred, &eqbeg, &eqend); + + // loop on larger half + if (eqbeg - begin > end - eqend) + { + sort(eqend, end, pred); + end = eqbeg; + } + else + { + sort(begin, eqbeg, pred); + begin = eqend; + } + } + + // insertion sort small chunk + if (begin != end) insertion_sort(begin, end, pred, &*begin); + } +PUGI__NS_END + +// Allocator used for AST and evaluation stacks +PUGI__NS_BEGIN + struct xpath_memory_block + { + xpath_memory_block* next; + + char data[ + #ifdef PUGIXML_MEMORY_XPATH_PAGE_SIZE + PUGIXML_MEMORY_XPATH_PAGE_SIZE + #else + 4096 + #endif + ]; + }; + + class xpath_allocator + { + xpath_memory_block* _root; + size_t _root_size; + + public: + #ifdef PUGIXML_NO_EXCEPTIONS + jmp_buf* error_handler; + #endif + + xpath_allocator(xpath_memory_block* root, size_t root_size = 0): _root(root), _root_size(root_size) + { + #ifdef PUGIXML_NO_EXCEPTIONS + error_handler = 0; + #endif + } + + void* allocate_nothrow(size_t size) + { + const size_t block_capacity = sizeof(_root->data); + + // align size so that we're able to store pointers in subsequent blocks + size = (size + sizeof(void*) - 1) & ~(sizeof(void*) - 1); + + 
if (_root_size + size <= block_capacity) + { + void* buf = _root->data + _root_size; + _root_size += size; + return buf; + } + else + { + size_t block_data_size = (size > block_capacity) ? size : block_capacity; + size_t block_size = block_data_size + offsetof(xpath_memory_block, data); + + xpath_memory_block* block = static_cast(xml_memory::allocate(block_size)); + if (!block) return 0; + + block->next = _root; + + _root = block; + _root_size = size; + + return block->data; + } + } + + void* allocate(size_t size) + { + void* result = allocate_nothrow(size); + + if (!result) + { + #ifdef PUGIXML_NO_EXCEPTIONS + assert(error_handler); + longjmp(*error_handler, 1); + #else + throw std::bad_alloc(); + #endif + } + + return result; + } + + void* reallocate(void* ptr, size_t old_size, size_t new_size) + { + // align size so that we're able to store pointers in subsequent blocks + old_size = (old_size + sizeof(void*) - 1) & ~(sizeof(void*) - 1); + new_size = (new_size + sizeof(void*) - 1) & ~(sizeof(void*) - 1); + + // we can only reallocate the last object + assert(ptr == 0 || static_cast(ptr) + old_size == _root->data + _root_size); + + // adjust root size so that we have not allocated the object at all + bool only_object = (_root_size == old_size); + + if (ptr) _root_size -= old_size; + + // allocate a new version (this will obviously reuse the memory if possible) + void* result = allocate(new_size); + assert(result); + + // we have a new block + if (result != ptr && ptr) + { + // copy old data + assert(new_size > old_size); + memcpy(result, ptr, old_size); + + // free the previous page if it had no other objects + if (only_object) + { + assert(_root->data == result); + assert(_root->next); + + xpath_memory_block* next = _root->next->next; + + if (next) + { + // deallocate the whole page, unless it was the first one + xml_memory::deallocate(_root->next); + _root->next = next; + } + } + } + + return result; + } + + void revert(const xpath_allocator& state) + { + // 
free all new pages + xpath_memory_block* cur = _root; + + while (cur != state._root) + { + xpath_memory_block* next = cur->next; + + xml_memory::deallocate(cur); + + cur = next; + } + + // restore state + _root = state._root; + _root_size = state._root_size; + } + + void release() + { + xpath_memory_block* cur = _root; + assert(cur); + + while (cur->next) + { + xpath_memory_block* next = cur->next; + + xml_memory::deallocate(cur); + + cur = next; + } + } + }; + + struct xpath_allocator_capture + { + xpath_allocator_capture(xpath_allocator* alloc): _target(alloc), _state(*alloc) + { + } + + ~xpath_allocator_capture() + { + _target->revert(_state); + } + + xpath_allocator* _target; + xpath_allocator _state; + }; + + struct xpath_stack + { + xpath_allocator* result; + xpath_allocator* temp; + }; + + struct xpath_stack_data + { + xpath_memory_block blocks[2]; + xpath_allocator result; + xpath_allocator temp; + xpath_stack stack; + + #ifdef PUGIXML_NO_EXCEPTIONS + jmp_buf error_handler; + #endif + + xpath_stack_data(): result(blocks + 0), temp(blocks + 1) + { + blocks[0].next = blocks[1].next = 0; + + stack.result = &result; + stack.temp = &temp; + + #ifdef PUGIXML_NO_EXCEPTIONS + result.error_handler = temp.error_handler = &error_handler; + #endif + } + + ~xpath_stack_data() + { + result.release(); + temp.release(); + } + }; +PUGI__NS_END + +// String class +PUGI__NS_BEGIN + class xpath_string + { + const char_t* _buffer; + bool _uses_heap; + + static char_t* duplicate_string(const char_t* string, size_t length, xpath_allocator* alloc) + { + char_t* result = static_cast(alloc->allocate((length + 1) * sizeof(char_t))); + assert(result); + + memcpy(result, string, length * sizeof(char_t)); + result[length] = 0; + + return result; + } + + static char_t* duplicate_string(const char_t* string, xpath_allocator* alloc) + { + return duplicate_string(string, strlength(string), alloc); + } + + public: + xpath_string(): _buffer(PUGIXML_TEXT("")), _uses_heap(false) + { + } + + 
explicit xpath_string(const char_t* str, xpath_allocator* alloc) + { + bool empty_ = (*str == 0); + + _buffer = empty_ ? PUGIXML_TEXT("") : duplicate_string(str, alloc); + _uses_heap = !empty_; + } + + explicit xpath_string(const char_t* str, bool use_heap): _buffer(str), _uses_heap(use_heap) + { + } + + xpath_string(const char_t* begin, const char_t* end, xpath_allocator* alloc) + { + assert(begin <= end); + + bool empty_ = (begin == end); + + _buffer = empty_ ? PUGIXML_TEXT("") : duplicate_string(begin, static_cast(end - begin), alloc); + _uses_heap = !empty_; + } + + void append(const xpath_string& o, xpath_allocator* alloc) + { + // skip empty sources + if (!*o._buffer) return; + + // fast append for constant empty target and constant source + if (!*_buffer && !_uses_heap && !o._uses_heap) + { + _buffer = o._buffer; + } + else + { + // need to make heap copy + size_t target_length = strlength(_buffer); + size_t source_length = strlength(o._buffer); + size_t result_length = target_length + source_length; + + // allocate new buffer + char_t* result = static_cast(alloc->reallocate(_uses_heap ? 
const_cast(_buffer) : 0, (target_length + 1) * sizeof(char_t), (result_length + 1) * sizeof(char_t))); + assert(result); + + // append first string to the new buffer in case there was no reallocation + if (!_uses_heap) memcpy(result, _buffer, target_length * sizeof(char_t)); + + // append second string to the new buffer + memcpy(result + target_length, o._buffer, source_length * sizeof(char_t)); + result[result_length] = 0; + + // finalize + _buffer = result; + _uses_heap = true; + } + } + + const char_t* c_str() const + { + return _buffer; + } + + size_t length() const + { + return strlength(_buffer); + } + + char_t* data(xpath_allocator* alloc) + { + // make private heap copy + if (!_uses_heap) + { + _buffer = duplicate_string(_buffer, alloc); + _uses_heap = true; + } + + return const_cast(_buffer); + } + + bool empty() const + { + return *_buffer == 0; + } + + bool operator==(const xpath_string& o) const + { + return strequal(_buffer, o._buffer); + } + + bool operator!=(const xpath_string& o) const + { + return !strequal(_buffer, o._buffer); + } + + bool uses_heap() const + { + return _uses_heap; + } + }; + + PUGI__FN xpath_string xpath_string_const(const char_t* str) + { + return xpath_string(str, false); + } +PUGI__NS_END + +PUGI__NS_BEGIN + PUGI__FN bool starts_with(const char_t* string, const char_t* pattern) + { + while (*pattern && *string == *pattern) + { + string++; + pattern++; + } + + return *pattern == 0; + } + + PUGI__FN const char_t* find_char(const char_t* s, char_t c) + { + #ifdef PUGIXML_WCHAR_MODE + return wcschr(s, c); + #else + return strchr(s, c); + #endif + } + + PUGI__FN const char_t* find_substring(const char_t* s, const char_t* p) + { + #ifdef PUGIXML_WCHAR_MODE + // MSVC6 wcsstr bug workaround (if s is empty it always returns 0) + return (*p == 0) ? 
s : wcsstr(s, p); + #else + return strstr(s, p); + #endif + } + + // Converts symbol to lower case, if it is an ASCII one + PUGI__FN char_t tolower_ascii(char_t ch) + { + return static_cast(ch - 'A') < 26 ? static_cast(ch | ' ') : ch; + } + + PUGI__FN xpath_string string_value(const xpath_node& na, xpath_allocator* alloc) + { + if (na.attribute()) + return xpath_string_const(na.attribute().value()); + else + { + const xml_node& n = na.node(); + + switch (n.type()) + { + case node_pcdata: + case node_cdata: + case node_comment: + case node_pi: + return xpath_string_const(n.value()); + + case node_document: + case node_element: + { + xpath_string result; + + xml_node cur = n.first_child(); + + while (cur && cur != n) + { + if (cur.type() == node_pcdata || cur.type() == node_cdata) + result.append(xpath_string_const(cur.value()), alloc); + + if (cur.first_child()) + cur = cur.first_child(); + else if (cur.next_sibling()) + cur = cur.next_sibling(); + else + { + while (!cur.next_sibling() && cur != n) + cur = cur.parent(); + + if (cur != n) cur = cur.next_sibling(); + } + } + + return result; + } + + default: + return xpath_string(); + } + } + } + + PUGI__FN unsigned int node_height(xml_node n) + { + unsigned int result = 0; + + while (n) + { + ++result; + n = n.parent(); + } + + return result; + } + + PUGI__FN bool node_is_before(xml_node ln, unsigned int lh, xml_node rn, unsigned int rh) + { + // normalize heights + for (unsigned int i = rh; i < lh; i++) ln = ln.parent(); + for (unsigned int j = lh; j < rh; j++) rn = rn.parent(); + + // one node is the ancestor of the other + if (ln == rn) return lh < rh; + + // find common ancestor + while (ln.parent() != rn.parent()) + { + ln = ln.parent(); + rn = rn.parent(); + } + + // there is no common ancestor (the shared parent is null), nodes are from different documents + if (!ln.parent()) return ln < rn; + + // determine sibling order + for (; ln; ln = ln.next_sibling()) + if (ln == rn) + return true; + + return false; + } 
+ + PUGI__FN bool node_is_ancestor(xml_node parent, xml_node node) + { + while (node && node != parent) node = node.parent(); + + return parent && node == parent; + } + + PUGI__FN const void* document_order(const xpath_node& xnode) + { + xml_node_struct* node = xnode.node().internal_object(); + + if (node) + { + if (node->name && (node->header & xml_memory_page_name_allocated_mask) == 0) return node->name; + if (node->value && (node->header & xml_memory_page_value_allocated_mask) == 0) return node->value; + return 0; + } + + xml_attribute_struct* attr = xnode.attribute().internal_object(); + + if (attr) + { + if ((attr->header & xml_memory_page_name_allocated_mask) == 0) return attr->name; + if ((attr->header & xml_memory_page_value_allocated_mask) == 0) return attr->value; + return 0; + } + + return 0; + } + + struct document_order_comparator + { + bool operator()(const xpath_node& lhs, const xpath_node& rhs) const + { + // optimized document order based check + const void* lo = document_order(lhs); + const void* ro = document_order(rhs); + + if (lo && ro) return lo < ro; + + // slow comparison + xml_node ln = lhs.node(), rn = rhs.node(); + + // compare attributes + if (lhs.attribute() && rhs.attribute()) + { + // shared parent + if (lhs.parent() == rhs.parent()) + { + // determine sibling order + for (xml_attribute a = lhs.attribute(); a; a = a.next_attribute()) + if (a == rhs.attribute()) + return true; + + return false; + } + + // compare attribute parents + ln = lhs.parent(); + rn = rhs.parent(); + } + else if (lhs.attribute()) + { + // attributes go after the parent element + if (lhs.parent() == rhs.node()) return false; + + ln = lhs.parent(); + } + else if (rhs.attribute()) + { + // attributes go after the parent element + if (rhs.parent() == lhs.node()) return true; + + rn = rhs.parent(); + } + + if (ln == rn) return false; + + unsigned int lh = node_height(ln); + unsigned int rh = node_height(rn); + + return node_is_before(ln, lh, rn, rh); + } + }; + + 
struct duplicate_comparator + { + bool operator()(const xpath_node& lhs, const xpath_node& rhs) const + { + if (lhs.attribute()) return rhs.attribute() ? lhs.attribute() < rhs.attribute() : true; + else return rhs.attribute() ? false : lhs.node() < rhs.node(); + } + }; + + PUGI__FN double gen_nan() + { + #if defined(__STDC_IEC_559__) || ((FLT_RADIX - 0 == 2) && (FLT_MAX_EXP - 0 == 128) && (FLT_MANT_DIG - 0 == 24)) + union { float f; uint32_t i; } u[sizeof(float) == sizeof(uint32_t) ? 1 : -1]; + u[0].i = 0x7fc00000; + return u[0].f; + #else + // fallback + const volatile double zero = 0.0; + return zero / zero; + #endif + } + + PUGI__FN bool is_nan(double value) + { + #if defined(PUGI__MSVC_CRT_VERSION) || defined(__BORLANDC__) + return !!_isnan(value); + #elif defined(fpclassify) && defined(FP_NAN) + return fpclassify(value) == FP_NAN; + #else + // fallback + const volatile double v = value; + return v != v; + #endif + } + + PUGI__FN const char_t* convert_number_to_string_special(double value) + { + #if defined(PUGI__MSVC_CRT_VERSION) || defined(__BORLANDC__) + if (_finite(value)) return (value == 0) ? PUGIXML_TEXT("0") : 0; + if (_isnan(value)) return PUGIXML_TEXT("NaN"); + return value > 0 ? PUGIXML_TEXT("Infinity") : PUGIXML_TEXT("-Infinity"); + #elif defined(fpclassify) && defined(FP_NAN) && defined(FP_INFINITE) && defined(FP_ZERO) + switch (fpclassify(value)) + { + case FP_NAN: + return PUGIXML_TEXT("NaN"); + + case FP_INFINITE: + return value > 0 ? PUGIXML_TEXT("Infinity") : PUGIXML_TEXT("-Infinity"); + + case FP_ZERO: + return PUGIXML_TEXT("0"); + + default: + return 0; + } + #else + // fallback + const volatile double v = value; + + if (v == 0) return PUGIXML_TEXT("0"); + if (v != v) return PUGIXML_TEXT("NaN"); + if (v * 2 == v) return value > 0 ? 
PUGIXML_TEXT("Infinity") : PUGIXML_TEXT("-Infinity"); + return 0; + #endif + } + + PUGI__FN bool convert_number_to_boolean(double value) + { + return (value != 0 && !is_nan(value)); + } + + PUGI__FN void truncate_zeros(char* begin, char* end) + { + while (begin != end && end[-1] == '0') end--; + + *end = 0; + } + + // gets mantissa digits in the form of 0.xxxxx with 0. implied and the exponent +#if defined(PUGI__MSVC_CRT_VERSION) && PUGI__MSVC_CRT_VERSION >= 1400 && !defined(_WIN32_WCE) + PUGI__FN void convert_number_to_mantissa_exponent(double value, char* buffer, size_t buffer_size, char** out_mantissa, int* out_exponent) + { + // get base values + int sign, exponent; + _ecvt_s(buffer, buffer_size, value, DBL_DIG + 1, &exponent, &sign); + + // truncate redundant zeros + truncate_zeros(buffer, buffer + strlen(buffer)); + + // fill results + *out_mantissa = buffer; + *out_exponent = exponent; + } +#else + PUGI__FN void convert_number_to_mantissa_exponent(double value, char* buffer, size_t buffer_size, char** out_mantissa, int* out_exponent) + { + // get a scientific notation value with IEEE DBL_DIG decimals + sprintf(buffer, "%.*e", DBL_DIG, value); + assert(strlen(buffer) < buffer_size); + (void)!buffer_size; + + // get the exponent (possibly negative) + char* exponent_string = strchr(buffer, 'e'); + assert(exponent_string); + + int exponent = atoi(exponent_string + 1); + + // extract mantissa string: skip sign + char* mantissa = buffer[0] == '-' ? 
buffer + 1 : buffer; + assert(mantissa[0] != '0' && mantissa[1] == '.'); + + // divide mantissa by 10 to eliminate integer part + mantissa[1] = mantissa[0]; + mantissa++; + exponent++; + + // remove extra mantissa digits and zero-terminate mantissa + truncate_zeros(mantissa, exponent_string); + + // fill results + *out_mantissa = mantissa; + *out_exponent = exponent; + } +#endif + + PUGI__FN xpath_string convert_number_to_string(double value, xpath_allocator* alloc) + { + // try special number conversion + const char_t* special = convert_number_to_string_special(value); + if (special) return xpath_string_const(special); + + // get mantissa + exponent form + char mantissa_buffer[64]; + + char* mantissa; + int exponent; + convert_number_to_mantissa_exponent(value, mantissa_buffer, sizeof(mantissa_buffer), &mantissa, &exponent); + + // make the number! + char_t result[512]; + char_t* s = result; + + // sign + if (value < 0) *s++ = '-'; + + // integer part + if (exponent <= 0) + { + *s++ = '0'; + } + else + { + while (exponent > 0) + { + assert(*mantissa == 0 || static_cast(*mantissa - '0') <= 9); + *s++ = *mantissa ? 
*mantissa++ : '0'; + exponent--; + } + } + + // fractional part + if (*mantissa) + { + // decimal point + *s++ = '.'; + + // extra zeroes from negative exponent + while (exponent < 0) + { + *s++ = '0'; + exponent++; + } + + // extra mantissa digits + while (*mantissa) + { + assert(static_cast(*mantissa - '0') <= 9); + *s++ = *mantissa++; + } + } + + // zero-terminate + assert(s < result + sizeof(result) / sizeof(result[0])); + *s = 0; + + return xpath_string(result, alloc); + } + + PUGI__FN bool check_string_to_number_format(const char_t* string) + { + // parse leading whitespace + while (PUGI__IS_CHARTYPE(*string, ct_space)) ++string; + + // parse sign + if (*string == '-') ++string; + + if (!*string) return false; + + // if there is no integer part, there should be a decimal part with at least one digit + if (!PUGI__IS_CHARTYPEX(string[0], ctx_digit) && (string[0] != '.' || !PUGI__IS_CHARTYPEX(string[1], ctx_digit))) return false; + + // parse integer part + while (PUGI__IS_CHARTYPEX(*string, ctx_digit)) ++string; + + // parse decimal part + if (*string == '.') + { + ++string; + + while (PUGI__IS_CHARTYPEX(*string, ctx_digit)) ++string; + } + + // parse trailing whitespace + while (PUGI__IS_CHARTYPE(*string, ct_space)) ++string; + + return *string == 0; + } + + PUGI__FN double convert_string_to_number(const char_t* string) + { + // check string format + if (!check_string_to_number_format(string)) return gen_nan(); + + // parse string + #ifdef PUGIXML_WCHAR_MODE + return wcstod(string, 0); + #else + return atof(string); + #endif + } + + PUGI__FN bool convert_string_to_number(const char_t* begin, const char_t* end, double* out_result) + { + char_t buffer[32]; + + size_t length = static_cast(end - begin); + char_t* scratch = buffer; + + if (length >= sizeof(buffer) / sizeof(buffer[0])) + { + // need to make dummy on-heap copy + scratch = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!scratch) return false; + } + + // copy string to 
zero-terminated buffer and perform conversion + memcpy(scratch, begin, length * sizeof(char_t)); + scratch[length] = 0; + + *out_result = convert_string_to_number(scratch); + + // free dummy buffer + if (scratch != buffer) xml_memory::deallocate(scratch); + + return true; + } + + PUGI__FN double round_nearest(double value) + { + return floor(value + 0.5); + } + + PUGI__FN double round_nearest_nzero(double value) + { + // same as round_nearest, but returns -0 for [-0.5, -0] + // ceil is used to differentiate between +0 and -0 (we return -0 for [-0.5, -0] and +0 for +0) + return (value >= -0.5 && value <= 0) ? ceil(value) : floor(value + 0.5); + } + + PUGI__FN const char_t* qualified_name(const xpath_node& node) + { + return node.attribute() ? node.attribute().name() : node.node().name(); + } + + PUGI__FN const char_t* local_name(const xpath_node& node) + { + const char_t* name = qualified_name(node); + const char_t* p = find_char(name, ':'); + + return p ? p + 1 : name; + } + + struct namespace_uri_predicate + { + const char_t* prefix; + size_t prefix_length; + + namespace_uri_predicate(const char_t* name) + { + const char_t* pos = find_char(name, ':'); + + prefix = pos ? name : 0; + prefix_length = pos ? static_cast(pos - name) : 0; + } + + bool operator()(const xml_attribute& a) const + { + const char_t* name = a.name(); + + if (!starts_with(name, PUGIXML_TEXT("xmlns"))) return false; + + return prefix ? 
name[5] == ':' && strequalrange(name + 6, prefix, prefix_length) : name[5] == 0; + } + }; + + PUGI__FN const char_t* namespace_uri(const xml_node& node) + { + namespace_uri_predicate pred = node.name(); + + xml_node p = node; + + while (p) + { + xml_attribute a = p.find_attribute(pred); + + if (a) return a.value(); + + p = p.parent(); + } + + return PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* namespace_uri(const xml_attribute& attr, const xml_node& parent) + { + namespace_uri_predicate pred = attr.name(); + + // Default namespace does not apply to attributes + if (!pred.prefix) return PUGIXML_TEXT(""); + + xml_node p = parent; + + while (p) + { + xml_attribute a = p.find_attribute(pred); + + if (a) return a.value(); + + p = p.parent(); + } + + return PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* namespace_uri(const xpath_node& node) + { + return node.attribute() ? namespace_uri(node.attribute(), node.parent()) : namespace_uri(node.node()); + } + + PUGI__FN void normalize_space(char_t* buffer) + { + char_t* write = buffer; + + for (char_t* it = buffer; *it; ) + { + char_t ch = *it++; + + if (PUGI__IS_CHARTYPE(ch, ct_space)) + { + // replace whitespace sequence with single space + while (PUGI__IS_CHARTYPE(*it, ct_space)) it++; + + // avoid leading spaces + if (write != buffer) *write++ = ' '; + } + else *write++ = ch; + } + + // remove trailing space + if (write != buffer && PUGI__IS_CHARTYPE(write[-1], ct_space)) write--; + + // zero-terminate + *write = 0; + } + + PUGI__FN void translate(char_t* buffer, const char_t* from, const char_t* to) + { + size_t to_length = strlength(to); + + char_t* write = buffer; + + while (*buffer) + { + PUGI__DMC_VOLATILE char_t ch = *buffer++; + + const char_t* pos = find_char(from, ch); + + if (!pos) + *write++ = ch; // do not process + else if (static_cast(pos - from) < to_length) + *write++ = to[pos - from]; // replace + } + + // zero-terminate + *write = 0; + } + + struct xpath_variable_boolean: xpath_variable + { + 
xpath_variable_boolean(): value(false) + { + } + + bool value; + char_t name[1]; + }; + + struct xpath_variable_number: xpath_variable + { + xpath_variable_number(): value(0) + { + } + + double value; + char_t name[1]; + }; + + struct xpath_variable_string: xpath_variable + { + xpath_variable_string(): value(0) + { + } + + ~xpath_variable_string() + { + if (value) xml_memory::deallocate(value); + } + + char_t* value; + char_t name[1]; + }; + + struct xpath_variable_node_set: xpath_variable + { + xpath_node_set value; + char_t name[1]; + }; + + static const xpath_node_set dummy_node_set; + + PUGI__FN unsigned int hash_string(const char_t* str) + { + // Jenkins one-at-a-time hash (http://en.wikipedia.org/wiki/Jenkins_hash_function#one-at-a-time) + unsigned int result = 0; + + while (*str) + { + result += static_cast(*str++); + result += result << 10; + result ^= result >> 6; + } + + result += result << 3; + result ^= result >> 11; + result += result << 15; + + return result; + } + + template PUGI__FN T* new_xpath_variable(const char_t* name) + { + size_t length = strlength(name); + if (length == 0) return 0; // empty variable names are invalid + + // $$ we can't use offsetof(T, name) because T is non-POD, so we just allocate additional length characters + void* memory = xml_memory::allocate(sizeof(T) + length * sizeof(char_t)); + if (!memory) return 0; + + T* result = new (memory) T(); + + memcpy(result->name, name, (length + 1) * sizeof(char_t)); + + return result; + } + + PUGI__FN xpath_variable* new_xpath_variable(xpath_value_type type, const char_t* name) + { + switch (type) + { + case xpath_type_node_set: + return new_xpath_variable(name); + + case xpath_type_number: + return new_xpath_variable(name); + + case xpath_type_string: + return new_xpath_variable(name); + + case xpath_type_boolean: + return new_xpath_variable(name); + + default: + return 0; + } + } + + template PUGI__FN void delete_xpath_variable(T* var) + { + var->~T(); + xml_memory::deallocate(var); 
+ } + + PUGI__FN void delete_xpath_variable(xpath_value_type type, xpath_variable* var) + { + switch (type) + { + case xpath_type_node_set: + delete_xpath_variable(static_cast(var)); + break; + + case xpath_type_number: + delete_xpath_variable(static_cast(var)); + break; + + case xpath_type_string: + delete_xpath_variable(static_cast(var)); + break; + + case xpath_type_boolean: + delete_xpath_variable(static_cast(var)); + break; + + default: + assert(!"Invalid variable type"); + } + } + + PUGI__FN xpath_variable* get_variable(xpath_variable_set* set, const char_t* begin, const char_t* end) + { + char_t buffer[32]; + + size_t length = static_cast(end - begin); + char_t* scratch = buffer; + + if (length >= sizeof(buffer) / sizeof(buffer[0])) + { + // need to make dummy on-heap copy + scratch = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!scratch) return 0; + } + + // copy string to zero-terminated buffer and perform lookup + memcpy(scratch, begin, length * sizeof(char_t)); + scratch[length] = 0; + + xpath_variable* result = set->get(scratch); + + // free dummy buffer + if (scratch != buffer) xml_memory::deallocate(scratch); + + return result; + } +PUGI__NS_END + +// Internal node set class +PUGI__NS_BEGIN + PUGI__FN xpath_node_set::type_t xpath_sort(xpath_node* begin, xpath_node* end, xpath_node_set::type_t type, bool rev) + { + xpath_node_set::type_t order = rev ? 
xpath_node_set::type_sorted_reverse : xpath_node_set::type_sorted; + + if (type == xpath_node_set::type_unsorted) + { + sort(begin, end, document_order_comparator()); + + type = xpath_node_set::type_sorted; + } + + if (type != order) reverse(begin, end); + + return order; + } + + PUGI__FN xpath_node xpath_first(const xpath_node* begin, const xpath_node* end, xpath_node_set::type_t type) + { + if (begin == end) return xpath_node(); + + switch (type) + { + case xpath_node_set::type_sorted: + return *begin; + + case xpath_node_set::type_sorted_reverse: + return *(end - 1); + + case xpath_node_set::type_unsorted: + return *min_element(begin, end, document_order_comparator()); + + default: + assert(!"Invalid node set type"); + return xpath_node(); + } + } + + class xpath_node_set_raw + { + xpath_node_set::type_t _type; + + xpath_node* _begin; + xpath_node* _end; + xpath_node* _eos; + + public: + xpath_node_set_raw(): _type(xpath_node_set::type_unsorted), _begin(0), _end(0), _eos(0) + { + } + + xpath_node* begin() const + { + return _begin; + } + + xpath_node* end() const + { + return _end; + } + + bool empty() const + { + return _begin == _end; + } + + size_t size() const + { + return static_cast(_end - _begin); + } + + xpath_node first() const + { + return xpath_first(_begin, _end, _type); + } + + void push_back(const xpath_node& node, xpath_allocator* alloc) + { + if (_end == _eos) + { + size_t capacity = static_cast(_eos - _begin); + + // get new capacity (1.5x rule) + size_t new_capacity = capacity + capacity / 2 + 1; + + // reallocate the old array or allocate a new one + xpath_node* data = static_cast(alloc->reallocate(_begin, capacity * sizeof(xpath_node), new_capacity * sizeof(xpath_node))); + assert(data); + + // finalize + _begin = data; + _end = data + capacity; + _eos = data + new_capacity; + } + + *_end++ = node; + } + + void append(const xpath_node* begin_, const xpath_node* end_, xpath_allocator* alloc) + { + size_t size_ = static_cast(_end - _begin); + 
size_t capacity = static_cast(_eos - _begin); + size_t count = static_cast(end_ - begin_); + + if (size_ + count > capacity) + { + // reallocate the old array or allocate a new one + xpath_node* data = static_cast(alloc->reallocate(_begin, capacity * sizeof(xpath_node), (size_ + count) * sizeof(xpath_node))); + assert(data); + + // finalize + _begin = data; + _end = data + size_; + _eos = data + size_ + count; + } + + memcpy(_end, begin_, count * sizeof(xpath_node)); + _end += count; + } + + void sort_do() + { + _type = xpath_sort(_begin, _end, _type, false); + } + + void truncate(xpath_node* pos) + { + assert(_begin <= pos && pos <= _end); + + _end = pos; + } + + void remove_duplicates() + { + if (_type == xpath_node_set::type_unsorted) + sort(_begin, _end, duplicate_comparator()); + + _end = unique(_begin, _end); + } + + xpath_node_set::type_t type() const + { + return _type; + } + + void set_type(xpath_node_set::type_t value) + { + _type = value; + } + }; +PUGI__NS_END + +PUGI__NS_BEGIN + struct xpath_context + { + xpath_node n; + size_t position, size; + + xpath_context(const xpath_node& n_, size_t position_, size_t size_): n(n_), position(position_), size(size_) + { + } + }; + + enum lexeme_t + { + lex_none = 0, + lex_equal, + lex_not_equal, + lex_less, + lex_greater, + lex_less_or_equal, + lex_greater_or_equal, + lex_plus, + lex_minus, + lex_multiply, + lex_union, + lex_var_ref, + lex_open_brace, + lex_close_brace, + lex_quoted_string, + lex_number, + lex_slash, + lex_double_slash, + lex_open_square_brace, + lex_close_square_brace, + lex_string, + lex_comma, + lex_axis_attribute, + lex_dot, + lex_double_dot, + lex_double_colon, + lex_eof + }; + + struct xpath_lexer_string + { + const char_t* begin; + const char_t* end; + + xpath_lexer_string(): begin(0), end(0) + { + } + + bool operator==(const char_t* other) const + { + size_t length = static_cast(end - begin); + + return strequalrange(other, begin, length); + } + }; + + class xpath_lexer + { + const char_t* 
_cur; + const char_t* _cur_lexeme_pos; + xpath_lexer_string _cur_lexeme_contents; + + lexeme_t _cur_lexeme; + + public: + explicit xpath_lexer(const char_t* query): _cur(query) + { + next(); + } + + const char_t* state() const + { + return _cur; + } + + void next() + { + const char_t* cur = _cur; + + while (PUGI__IS_CHARTYPE(*cur, ct_space)) ++cur; + + // save lexeme position for error reporting + _cur_lexeme_pos = cur; + + switch (*cur) + { + case 0: + _cur_lexeme = lex_eof; + break; + + case '>': + if (*(cur+1) == '=') + { + cur += 2; + _cur_lexeme = lex_greater_or_equal; + } + else + { + cur += 1; + _cur_lexeme = lex_greater; + } + break; + + case '<': + if (*(cur+1) == '=') + { + cur += 2; + _cur_lexeme = lex_less_or_equal; + } + else + { + cur += 1; + _cur_lexeme = lex_less; + } + break; + + case '!': + if (*(cur+1) == '=') + { + cur += 2; + _cur_lexeme = lex_not_equal; + } + else + { + _cur_lexeme = lex_none; + } + break; + + case '=': + cur += 1; + _cur_lexeme = lex_equal; + + break; + + case '+': + cur += 1; + _cur_lexeme = lex_plus; + + break; + + case '-': + cur += 1; + _cur_lexeme = lex_minus; + + break; + + case '*': + cur += 1; + _cur_lexeme = lex_multiply; + + break; + + case '|': + cur += 1; + _cur_lexeme = lex_union; + + break; + + case '$': + cur += 1; + + if (PUGI__IS_CHARTYPEX(*cur, ctx_start_symbol)) + { + _cur_lexeme_contents.begin = cur; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_symbol)) cur++; + + if (cur[0] == ':' && PUGI__IS_CHARTYPEX(cur[1], ctx_symbol)) // qname + { + cur++; // : + + while (PUGI__IS_CHARTYPEX(*cur, ctx_symbol)) cur++; + } + + _cur_lexeme_contents.end = cur; + + _cur_lexeme = lex_var_ref; + } + else + { + _cur_lexeme = lex_none; + } + + break; + + case '(': + cur += 1; + _cur_lexeme = lex_open_brace; + + break; + + case ')': + cur += 1; + _cur_lexeme = lex_close_brace; + + break; + + case '[': + cur += 1; + _cur_lexeme = lex_open_square_brace; + + break; + + case ']': + cur += 1; + _cur_lexeme = lex_close_square_brace; + + 
break; + + case ',': + cur += 1; + _cur_lexeme = lex_comma; + + break; + + case '/': + if (*(cur+1) == '/') + { + cur += 2; + _cur_lexeme = lex_double_slash; + } + else + { + cur += 1; + _cur_lexeme = lex_slash; + } + break; + + case '.': + if (*(cur+1) == '.') + { + cur += 2; + _cur_lexeme = lex_double_dot; + } + else if (PUGI__IS_CHARTYPEX(*(cur+1), ctx_digit)) + { + _cur_lexeme_contents.begin = cur; // . + + ++cur; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_digit)) cur++; + + _cur_lexeme_contents.end = cur; + + _cur_lexeme = lex_number; + } + else + { + cur += 1; + _cur_lexeme = lex_dot; + } + break; + + case '@': + cur += 1; + _cur_lexeme = lex_axis_attribute; + + break; + + case '"': + case '\'': + { + char_t terminator = *cur; + + ++cur; + + _cur_lexeme_contents.begin = cur; + while (*cur && *cur != terminator) cur++; + _cur_lexeme_contents.end = cur; + + if (!*cur) + _cur_lexeme = lex_none; + else + { + cur += 1; + _cur_lexeme = lex_quoted_string; + } + + break; + } + + case ':': + if (*(cur+1) == ':') + { + cur += 2; + _cur_lexeme = lex_double_colon; + } + else + { + _cur_lexeme = lex_none; + } + break; + + default: + if (PUGI__IS_CHARTYPEX(*cur, ctx_digit)) + { + _cur_lexeme_contents.begin = cur; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_digit)) cur++; + + if (*cur == '.') + { + cur++; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_digit)) cur++; + } + + _cur_lexeme_contents.end = cur; + + _cur_lexeme = lex_number; + } + else if (PUGI__IS_CHARTYPEX(*cur, ctx_start_symbol)) + { + _cur_lexeme_contents.begin = cur; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_symbol)) cur++; + + if (cur[0] == ':') + { + if (cur[1] == '*') // namespace test ncname:* + { + cur += 2; // :* + } + else if (PUGI__IS_CHARTYPEX(cur[1], ctx_symbol)) // namespace test qname + { + cur++; // : + + while (PUGI__IS_CHARTYPEX(*cur, ctx_symbol)) cur++; + } + } + + _cur_lexeme_contents.end = cur; + + _cur_lexeme = lex_string; + } + else + { + _cur_lexeme = lex_none; + } + } + + _cur = cur; + } + + lexeme_t 
current() const + { + return _cur_lexeme; + } + + const char_t* current_pos() const + { + return _cur_lexeme_pos; + } + + const xpath_lexer_string& contents() const + { + assert(_cur_lexeme == lex_var_ref || _cur_lexeme == lex_number || _cur_lexeme == lex_string || _cur_lexeme == lex_quoted_string); + + return _cur_lexeme_contents; + } + }; + + enum ast_type_t + { + ast_op_or, // left or right + ast_op_and, // left and right + ast_op_equal, // left = right + ast_op_not_equal, // left != right + ast_op_less, // left < right + ast_op_greater, // left > right + ast_op_less_or_equal, // left <= right + ast_op_greater_or_equal, // left >= right + ast_op_add, // left + right + ast_op_subtract, // left - right + ast_op_multiply, // left * right + ast_op_divide, // left / right + ast_op_mod, // left % right + ast_op_negate, // left - right + ast_op_union, // left | right + ast_predicate, // apply predicate to set; next points to next predicate + ast_filter, // select * from left where right + ast_filter_posinv, // select * from left where right; proximity position invariant + ast_string_constant, // string constant + ast_number_constant, // number constant + ast_variable, // variable + ast_func_last, // last() + ast_func_position, // position() + ast_func_count, // count(left) + ast_func_id, // id(left) + ast_func_local_name_0, // local-name() + ast_func_local_name_1, // local-name(left) + ast_func_namespace_uri_0, // namespace-uri() + ast_func_namespace_uri_1, // namespace-uri(left) + ast_func_name_0, // name() + ast_func_name_1, // name(left) + ast_func_string_0, // string() + ast_func_string_1, // string(left) + ast_func_concat, // concat(left, right, siblings) + ast_func_starts_with, // starts_with(left, right) + ast_func_contains, // contains(left, right) + ast_func_substring_before, // substring-before(left, right) + ast_func_substring_after, // substring-after(left, right) + ast_func_substring_2, // substring(left, right) + ast_func_substring_3, // substring(left, 
right, third) + ast_func_string_length_0, // string-length() + ast_func_string_length_1, // string-length(left) + ast_func_normalize_space_0, // normalize-space() + ast_func_normalize_space_1, // normalize-space(left) + ast_func_translate, // translate(left, right, third) + ast_func_boolean, // boolean(left) + ast_func_not, // not(left) + ast_func_true, // true() + ast_func_false, // false() + ast_func_lang, // lang(left) + ast_func_number_0, // number() + ast_func_number_1, // number(left) + ast_func_sum, // sum(left) + ast_func_floor, // floor(left) + ast_func_ceiling, // ceiling(left) + ast_func_round, // round(left) + ast_step, // process set left with step + ast_step_root // select root node + }; + + enum axis_t + { + axis_ancestor, + axis_ancestor_or_self, + axis_attribute, + axis_child, + axis_descendant, + axis_descendant_or_self, + axis_following, + axis_following_sibling, + axis_namespace, + axis_parent, + axis_preceding, + axis_preceding_sibling, + axis_self + }; + + enum nodetest_t + { + nodetest_none, + nodetest_name, + nodetest_type_node, + nodetest_type_comment, + nodetest_type_pi, + nodetest_type_text, + nodetest_pi, + nodetest_all, + nodetest_all_in_namespace + }; + + template struct axis_to_type + { + static const axis_t axis; + }; + + template const axis_t axis_to_type::axis = N; + + class xpath_ast_node + { + private: + // node type + char _type; + char _rettype; + + // for ast_step / ast_predicate + char _axis; + char _test; + + // tree node structure + xpath_ast_node* _left; + xpath_ast_node* _right; + xpath_ast_node* _next; + + union + { + // value for ast_string_constant + const char_t* string; + // value for ast_number_constant + double number; + // variable for ast_variable + xpath_variable* variable; + // node test for ast_step (node name/namespace/node type/pi target) + const char_t* nodetest; + } _data; + + xpath_ast_node(const xpath_ast_node&); + xpath_ast_node& operator=(const xpath_ast_node&); + + template static bool 
compare_eq(xpath_ast_node* lhs, xpath_ast_node* rhs, const xpath_context& c, const xpath_stack& stack, const Comp& comp) + { + xpath_value_type lt = lhs->rettype(), rt = rhs->rettype(); + + if (lt != xpath_type_node_set && rt != xpath_type_node_set) + { + if (lt == xpath_type_boolean || rt == xpath_type_boolean) + return comp(lhs->eval_boolean(c, stack), rhs->eval_boolean(c, stack)); + else if (lt == xpath_type_number || rt == xpath_type_number) + return comp(lhs->eval_number(c, stack), rhs->eval_number(c, stack)); + else if (lt == xpath_type_string || rt == xpath_type_string) + { + xpath_allocator_capture cr(stack.result); + + xpath_string ls = lhs->eval_string(c, stack); + xpath_string rs = rhs->eval_string(c, stack); + + return comp(ls, rs); + } + } + else if (lt == xpath_type_node_set && rt == xpath_type_node_set) + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ls = lhs->eval_node_set(c, stack); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack); + + for (const xpath_node* li = ls.begin(); li != ls.end(); ++li) + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture cri(stack.result); + + if (comp(string_value(*li, stack.result), string_value(*ri, stack.result))) + return true; + } + + return false; + } + else + { + if (lt == xpath_type_node_set) + { + swap(lhs, rhs); + swap(lt, rt); + } + + if (lt == xpath_type_boolean) + return comp(lhs->eval_boolean(c, stack), rhs->eval_boolean(c, stack)); + else if (lt == xpath_type_number) + { + xpath_allocator_capture cr(stack.result); + + double l = lhs->eval_number(c, stack); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack); + + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture cri(stack.result); + + if (comp(l, convert_string_to_number(string_value(*ri, stack.result).c_str()))) + return true; + } + + return false; + } + else if (lt == xpath_type_string) + { + xpath_allocator_capture cr(stack.result); + + 
xpath_string l = lhs->eval_string(c, stack); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack); + + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture cri(stack.result); + + if (comp(l, string_value(*ri, stack.result))) + return true; + } + + return false; + } + } + + assert(!"Wrong types"); + return false; + } + + template static bool compare_rel(xpath_ast_node* lhs, xpath_ast_node* rhs, const xpath_context& c, const xpath_stack& stack, const Comp& comp) + { + xpath_value_type lt = lhs->rettype(), rt = rhs->rettype(); + + if (lt != xpath_type_node_set && rt != xpath_type_node_set) + return comp(lhs->eval_number(c, stack), rhs->eval_number(c, stack)); + else if (lt == xpath_type_node_set && rt == xpath_type_node_set) + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ls = lhs->eval_node_set(c, stack); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack); + + for (const xpath_node* li = ls.begin(); li != ls.end(); ++li) + { + xpath_allocator_capture cri(stack.result); + + double l = convert_string_to_number(string_value(*li, stack.result).c_str()); + + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture crii(stack.result); + + if (comp(l, convert_string_to_number(string_value(*ri, stack.result).c_str()))) + return true; + } + } + + return false; + } + else if (lt != xpath_type_node_set && rt == xpath_type_node_set) + { + xpath_allocator_capture cr(stack.result); + + double l = lhs->eval_number(c, stack); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack); + + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture cri(stack.result); + + if (comp(l, convert_string_to_number(string_value(*ri, stack.result).c_str()))) + return true; + } + + return false; + } + else if (lt == xpath_type_node_set && rt != xpath_type_node_set) + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ls = lhs->eval_node_set(c, 
stack); + double r = rhs->eval_number(c, stack); + + for (const xpath_node* li = ls.begin(); li != ls.end(); ++li) + { + xpath_allocator_capture cri(stack.result); + + if (comp(convert_string_to_number(string_value(*li, stack.result).c_str()), r)) + return true; + } + + return false; + } + else + { + assert(!"Wrong types"); + return false; + } + } + + void apply_predicate(xpath_node_set_raw& ns, size_t first, xpath_ast_node* expr, const xpath_stack& stack) + { + assert(ns.size() >= first); + + size_t i = 1; + size_t size = ns.size() - first; + + xpath_node* last = ns.begin() + first; + + // remove_if... or well, sort of + for (xpath_node* it = last; it != ns.end(); ++it, ++i) + { + xpath_context c(*it, i, size); + + if (expr->rettype() == xpath_type_number) + { + if (expr->eval_number(c, stack) == i) + *last++ = *it; + } + else if (expr->eval_boolean(c, stack)) + *last++ = *it; + } + + ns.truncate(last); + } + + void apply_predicates(xpath_node_set_raw& ns, size_t first, const xpath_stack& stack) + { + if (ns.size() == first) return; + + for (xpath_ast_node* pred = _right; pred; pred = pred->_next) + { + apply_predicate(ns, first, pred->_left, stack); + } + } + + void step_push(xpath_node_set_raw& ns, const xml_attribute& a, const xml_node& parent, xpath_allocator* alloc) + { + if (!a) return; + + const char_t* name = a.name(); + + // There are no attribute nodes corresponding to attributes that declare namespaces + // That is, "xmlns:..." 
or "xmlns" + if (starts_with(name, PUGIXML_TEXT("xmlns")) && (name[5] == 0 || name[5] == ':')) return; + + switch (_test) + { + case nodetest_name: + if (strequal(name, _data.nodetest)) ns.push_back(xpath_node(a, parent), alloc); + break; + + case nodetest_type_node: + case nodetest_all: + ns.push_back(xpath_node(a, parent), alloc); + break; + + case nodetest_all_in_namespace: + if (starts_with(name, _data.nodetest)) + ns.push_back(xpath_node(a, parent), alloc); + break; + + default: + ; + } + } + + void step_push(xpath_node_set_raw& ns, const xml_node& n, xpath_allocator* alloc) + { + if (!n) return; + + switch (_test) + { + case nodetest_name: + if (n.type() == node_element && strequal(n.name(), _data.nodetest)) ns.push_back(n, alloc); + break; + + case nodetest_type_node: + ns.push_back(n, alloc); + break; + + case nodetest_type_comment: + if (n.type() == node_comment) + ns.push_back(n, alloc); + break; + + case nodetest_type_text: + if (n.type() == node_pcdata || n.type() == node_cdata) + ns.push_back(n, alloc); + break; + + case nodetest_type_pi: + if (n.type() == node_pi) + ns.push_back(n, alloc); + break; + + case nodetest_pi: + if (n.type() == node_pi && strequal(n.name(), _data.nodetest)) + ns.push_back(n, alloc); + break; + + case nodetest_all: + if (n.type() == node_element) + ns.push_back(n, alloc); + break; + + case nodetest_all_in_namespace: + if (n.type() == node_element && starts_with(n.name(), _data.nodetest)) + ns.push_back(n, alloc); + break; + + default: + assert(!"Unknown axis"); + } + } + + template void step_fill(xpath_node_set_raw& ns, const xml_node& n, xpath_allocator* alloc, T) + { + const axis_t axis = T::axis; + + switch (axis) + { + case axis_attribute: + { + for (xml_attribute a = n.first_attribute(); a; a = a.next_attribute()) + step_push(ns, a, n, alloc); + + break; + } + + case axis_child: + { + for (xml_node c = n.first_child(); c; c = c.next_sibling()) + step_push(ns, c, alloc); + + break; + } + + case axis_descendant: + case 
axis_descendant_or_self: + { + if (axis == axis_descendant_or_self) + step_push(ns, n, alloc); + + xml_node cur = n.first_child(); + + while (cur && cur != n) + { + step_push(ns, cur, alloc); + + if (cur.first_child()) + cur = cur.first_child(); + else if (cur.next_sibling()) + cur = cur.next_sibling(); + else + { + while (!cur.next_sibling() && cur != n) + cur = cur.parent(); + + if (cur != n) cur = cur.next_sibling(); + } + } + + break; + } + + case axis_following_sibling: + { + for (xml_node c = n.next_sibling(); c; c = c.next_sibling()) + step_push(ns, c, alloc); + + break; + } + + case axis_preceding_sibling: + { + for (xml_node c = n.previous_sibling(); c; c = c.previous_sibling()) + step_push(ns, c, alloc); + + break; + } + + case axis_following: + { + xml_node cur = n; + + // exit from this node so that we don't include descendants + while (cur && !cur.next_sibling()) cur = cur.parent(); + cur = cur.next_sibling(); + + for (;;) + { + step_push(ns, cur, alloc); + + if (cur.first_child()) + cur = cur.first_child(); + else if (cur.next_sibling()) + cur = cur.next_sibling(); + else + { + while (cur && !cur.next_sibling()) cur = cur.parent(); + cur = cur.next_sibling(); + + if (!cur) break; + } + } + + break; + } + + case axis_preceding: + { + xml_node cur = n; + + while (cur && !cur.previous_sibling()) cur = cur.parent(); + cur = cur.previous_sibling(); + + for (;;) + { + if (cur.last_child()) + cur = cur.last_child(); + else + { + // leaf node, can't be ancestor + step_push(ns, cur, alloc); + + if (cur.previous_sibling()) + cur = cur.previous_sibling(); + else + { + do + { + cur = cur.parent(); + if (!cur) break; + + if (!node_is_ancestor(cur, n)) step_push(ns, cur, alloc); + } + while (!cur.previous_sibling()); + + cur = cur.previous_sibling(); + + if (!cur) break; + } + } + } + + break; + } + + case axis_ancestor: + case axis_ancestor_or_self: + { + if (axis == axis_ancestor_or_self) + step_push(ns, n, alloc); + + xml_node cur = n.parent(); + + while (cur) + 
{ + step_push(ns, cur, alloc); + + cur = cur.parent(); + } + + break; + } + + case axis_self: + { + step_push(ns, n, alloc); + + break; + } + + case axis_parent: + { + if (n.parent()) step_push(ns, n.parent(), alloc); + + break; + } + + default: + assert(!"Unimplemented axis"); + } + } + + template void step_fill(xpath_node_set_raw& ns, const xml_attribute& a, const xml_node& p, xpath_allocator* alloc, T v) + { + const axis_t axis = T::axis; + + switch (axis) + { + case axis_ancestor: + case axis_ancestor_or_self: + { + if (axis == axis_ancestor_or_self && _test == nodetest_type_node) // reject attributes based on principal node type test + step_push(ns, a, p, alloc); + + xml_node cur = p; + + while (cur) + { + step_push(ns, cur, alloc); + + cur = cur.parent(); + } + + break; + } + + case axis_descendant_or_self: + case axis_self: + { + if (_test == nodetest_type_node) // reject attributes based on principal node type test + step_push(ns, a, p, alloc); + + break; + } + + case axis_following: + { + xml_node cur = p; + + for (;;) + { + if (cur.first_child()) + cur = cur.first_child(); + else if (cur.next_sibling()) + cur = cur.next_sibling(); + else + { + while (cur && !cur.next_sibling()) cur = cur.parent(); + cur = cur.next_sibling(); + + if (!cur) break; + } + + step_push(ns, cur, alloc); + } + + break; + } + + case axis_parent: + { + step_push(ns, p, alloc); + + break; + } + + case axis_preceding: + { + // preceding:: axis does not include attribute nodes and attribute ancestors (they are the same as parent's ancestors), so we can reuse node preceding + step_fill(ns, p, alloc, v); + break; + } + + default: + assert(!"Unimplemented axis"); + } + } + + template xpath_node_set_raw step_do(const xpath_context& c, const xpath_stack& stack, T v) + { + const axis_t axis = T::axis; + bool attributes = (axis == axis_ancestor || axis == axis_ancestor_or_self || axis == axis_descendant_or_self || axis == axis_following || axis == axis_parent || axis == axis_preceding || 
axis == axis_self); + + xpath_node_set_raw ns; + ns.set_type((axis == axis_ancestor || axis == axis_ancestor_or_self || axis == axis_preceding || axis == axis_preceding_sibling) ? xpath_node_set::type_sorted_reverse : xpath_node_set::type_sorted); + + if (_left) + { + xpath_node_set_raw s = _left->eval_node_set(c, stack); + + // self axis preserves the original order + if (axis == axis_self) ns.set_type(s.type()); + + for (const xpath_node* it = s.begin(); it != s.end(); ++it) + { + size_t size = ns.size(); + + // in general, all axes generate elements in a particular order, but there is no order guarantee if axis is applied to two nodes + if (axis != axis_self && size != 0) ns.set_type(xpath_node_set::type_unsorted); + + if (it->node()) + step_fill(ns, it->node(), stack.result, v); + else if (attributes) + step_fill(ns, it->attribute(), it->parent(), stack.result, v); + + apply_predicates(ns, size, stack); + } + } + else + { + if (c.n.node()) + step_fill(ns, c.n.node(), stack.result, v); + else if (attributes) + step_fill(ns, c.n.attribute(), c.n.parent(), stack.result, v); + + apply_predicates(ns, 0, stack); + } + + // child, attribute and self axes always generate unique set of nodes + // for other axis, if the set stayed sorted, it stayed unique because the traversal algorithms do not visit the same node twice + if (axis != axis_child && axis != axis_attribute && axis != axis_self && ns.type() == xpath_node_set::type_unsorted) + ns.remove_duplicates(); + + return ns; + } + + public: + xpath_ast_node(ast_type_t type, xpath_value_type rettype_, const char_t* value): + _type(static_cast(type)), _rettype(static_cast(rettype_)), _axis(0), _test(0), _left(0), _right(0), _next(0) + { + assert(type == ast_string_constant); + _data.string = value; + } + + xpath_ast_node(ast_type_t type, xpath_value_type rettype_, double value): + _type(static_cast(type)), _rettype(static_cast(rettype_)), _axis(0), _test(0), _left(0), _right(0), _next(0) + { + assert(type == 
ast_number_constant); + _data.number = value; + } + + xpath_ast_node(ast_type_t type, xpath_value_type rettype_, xpath_variable* value): + _type(static_cast(type)), _rettype(static_cast(rettype_)), _axis(0), _test(0), _left(0), _right(0), _next(0) + { + assert(type == ast_variable); + _data.variable = value; + } + + xpath_ast_node(ast_type_t type, xpath_value_type rettype_, xpath_ast_node* left = 0, xpath_ast_node* right = 0): + _type(static_cast(type)), _rettype(static_cast(rettype_)), _axis(0), _test(0), _left(left), _right(right), _next(0) + { + } + + xpath_ast_node(ast_type_t type, xpath_ast_node* left, axis_t axis, nodetest_t test, const char_t* contents): + _type(static_cast(type)), _rettype(xpath_type_node_set), _axis(static_cast(axis)), _test(static_cast(test)), _left(left), _right(0), _next(0) + { + _data.nodetest = contents; + } + + void set_next(xpath_ast_node* value) + { + _next = value; + } + + void set_right(xpath_ast_node* value) + { + _right = value; + } + + bool eval_boolean(const xpath_context& c, const xpath_stack& stack) + { + switch (_type) + { + case ast_op_or: + return _left->eval_boolean(c, stack) || _right->eval_boolean(c, stack); + + case ast_op_and: + return _left->eval_boolean(c, stack) && _right->eval_boolean(c, stack); + + case ast_op_equal: + return compare_eq(_left, _right, c, stack, equal_to()); + + case ast_op_not_equal: + return compare_eq(_left, _right, c, stack, not_equal_to()); + + case ast_op_less: + return compare_rel(_left, _right, c, stack, less()); + + case ast_op_greater: + return compare_rel(_right, _left, c, stack, less()); + + case ast_op_less_or_equal: + return compare_rel(_left, _right, c, stack, less_equal()); + + case ast_op_greater_or_equal: + return compare_rel(_right, _left, c, stack, less_equal()); + + case ast_func_starts_with: + { + xpath_allocator_capture cr(stack.result); + + xpath_string lr = _left->eval_string(c, stack); + xpath_string rr = _right->eval_string(c, stack); + + return starts_with(lr.c_str(), 
rr.c_str()); + } + + case ast_func_contains: + { + xpath_allocator_capture cr(stack.result); + + xpath_string lr = _left->eval_string(c, stack); + xpath_string rr = _right->eval_string(c, stack); + + return find_substring(lr.c_str(), rr.c_str()) != 0; + } + + case ast_func_boolean: + return _left->eval_boolean(c, stack); + + case ast_func_not: + return !_left->eval_boolean(c, stack); + + case ast_func_true: + return true; + + case ast_func_false: + return false; + + case ast_func_lang: + { + if (c.n.attribute()) return false; + + xpath_allocator_capture cr(stack.result); + + xpath_string lang = _left->eval_string(c, stack); + + for (xml_node n = c.n.node(); n; n = n.parent()) + { + xml_attribute a = n.attribute(PUGIXML_TEXT("xml:lang")); + + if (a) + { + const char_t* value = a.value(); + + // strnicmp / strncasecmp is not portable + for (const char_t* lit = lang.c_str(); *lit; ++lit) + { + if (tolower_ascii(*lit) != tolower_ascii(*value)) return false; + ++value; + } + + return *value == 0 || *value == '-'; + } + } + + return false; + } + + case ast_variable: + { + assert(_rettype == _data.variable->type()); + + if (_rettype == xpath_type_boolean) + return _data.variable->get_boolean(); + + // fallthrough to type conversion + } + + default: + { + switch (_rettype) + { + case xpath_type_number: + return convert_number_to_boolean(eval_number(c, stack)); + + case xpath_type_string: + { + xpath_allocator_capture cr(stack.result); + + return !eval_string(c, stack).empty(); + } + + case xpath_type_node_set: + { + xpath_allocator_capture cr(stack.result); + + return !eval_node_set(c, stack).empty(); + } + + default: + assert(!"Wrong expression for return type boolean"); + return false; + } + } + } + } + + double eval_number(const xpath_context& c, const xpath_stack& stack) + { + switch (_type) + { + case ast_op_add: + return _left->eval_number(c, stack) + _right->eval_number(c, stack); + + case ast_op_subtract: + return _left->eval_number(c, stack) - 
_right->eval_number(c, stack); + + case ast_op_multiply: + return _left->eval_number(c, stack) * _right->eval_number(c, stack); + + case ast_op_divide: + return _left->eval_number(c, stack) / _right->eval_number(c, stack); + + case ast_op_mod: + return fmod(_left->eval_number(c, stack), _right->eval_number(c, stack)); + + case ast_op_negate: + return -_left->eval_number(c, stack); + + case ast_number_constant: + return _data.number; + + case ast_func_last: + return static_cast(c.size); + + case ast_func_position: + return static_cast(c.position); + + case ast_func_count: + { + xpath_allocator_capture cr(stack.result); + + return static_cast(_left->eval_node_set(c, stack).size()); + } + + case ast_func_string_length_0: + { + xpath_allocator_capture cr(stack.result); + + return static_cast(string_value(c.n, stack.result).length()); + } + + case ast_func_string_length_1: + { + xpath_allocator_capture cr(stack.result); + + return static_cast(_left->eval_string(c, stack).length()); + } + + case ast_func_number_0: + { + xpath_allocator_capture cr(stack.result); + + return convert_string_to_number(string_value(c.n, stack.result).c_str()); + } + + case ast_func_number_1: + return _left->eval_number(c, stack); + + case ast_func_sum: + { + xpath_allocator_capture cr(stack.result); + + double r = 0; + + xpath_node_set_raw ns = _left->eval_node_set(c, stack); + + for (const xpath_node* it = ns.begin(); it != ns.end(); ++it) + { + xpath_allocator_capture cri(stack.result); + + r += convert_string_to_number(string_value(*it, stack.result).c_str()); + } + + return r; + } + + case ast_func_floor: + { + double r = _left->eval_number(c, stack); + + return r == r ? floor(r) : r; + } + + case ast_func_ceiling: + { + double r = _left->eval_number(c, stack); + + return r == r ? 
ceil(r) : r; + } + + case ast_func_round: + return round_nearest_nzero(_left->eval_number(c, stack)); + + case ast_variable: + { + assert(_rettype == _data.variable->type()); + + if (_rettype == xpath_type_number) + return _data.variable->get_number(); + + // fallthrough to type conversion + } + + default: + { + switch (_rettype) + { + case xpath_type_boolean: + return eval_boolean(c, stack) ? 1 : 0; + + case xpath_type_string: + { + xpath_allocator_capture cr(stack.result); + + return convert_string_to_number(eval_string(c, stack).c_str()); + } + + case xpath_type_node_set: + { + xpath_allocator_capture cr(stack.result); + + return convert_string_to_number(eval_string(c, stack).c_str()); + } + + default: + assert(!"Wrong expression for return type number"); + return 0; + } + + } + } + } + + xpath_string eval_string_concat(const xpath_context& c, const xpath_stack& stack) + { + assert(_type == ast_func_concat); + + xpath_allocator_capture ct(stack.temp); + + // count the string number + size_t count = 1; + for (xpath_ast_node* nc = _right; nc; nc = nc->_next) count++; + + // gather all strings + xpath_string static_buffer[4]; + xpath_string* buffer = static_buffer; + + // allocate on-heap for large concats + if (count > sizeof(static_buffer) / sizeof(static_buffer[0])) + { + buffer = static_cast(stack.temp->allocate(count * sizeof(xpath_string))); + assert(buffer); + } + + // evaluate all strings to temporary stack + xpath_stack swapped_stack = {stack.temp, stack.result}; + + buffer[0] = _left->eval_string(c, swapped_stack); + + size_t pos = 1; + for (xpath_ast_node* n = _right; n; n = n->_next, ++pos) buffer[pos] = n->eval_string(c, swapped_stack); + assert(pos == count); + + // get total length + size_t length = 0; + for (size_t i = 0; i < count; ++i) length += buffer[i].length(); + + // create final string + char_t* result = static_cast(stack.result->allocate((length + 1) * sizeof(char_t))); + assert(result); + + char_t* ri = result; + + for (size_t j = 0; j < 
count; ++j) + for (const char_t* bi = buffer[j].c_str(); *bi; ++bi) + *ri++ = *bi; + + *ri = 0; + + return xpath_string(result, true); + } + + xpath_string eval_string(const xpath_context& c, const xpath_stack& stack) + { + switch (_type) + { + case ast_string_constant: + return xpath_string_const(_data.string); + + case ast_func_local_name_0: + { + xpath_node na = c.n; + + return xpath_string_const(local_name(na)); + } + + case ast_func_local_name_1: + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ns = _left->eval_node_set(c, stack); + xpath_node na = ns.first(); + + return xpath_string_const(local_name(na)); + } + + case ast_func_name_0: + { + xpath_node na = c.n; + + return xpath_string_const(qualified_name(na)); + } + + case ast_func_name_1: + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ns = _left->eval_node_set(c, stack); + xpath_node na = ns.first(); + + return xpath_string_const(qualified_name(na)); + } + + case ast_func_namespace_uri_0: + { + xpath_node na = c.n; + + return xpath_string_const(namespace_uri(na)); + } + + case ast_func_namespace_uri_1: + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ns = _left->eval_node_set(c, stack); + xpath_node na = ns.first(); + + return xpath_string_const(namespace_uri(na)); + } + + case ast_func_string_0: + return string_value(c.n, stack.result); + + case ast_func_string_1: + return _left->eval_string(c, stack); + + case ast_func_concat: + return eval_string_concat(c, stack); + + case ast_func_substring_before: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, swapped_stack); + xpath_string p = _right->eval_string(c, swapped_stack); + + const char_t* pos = find_substring(s.c_str(), p.c_str()); + + return pos ? 
xpath_string(s.c_str(), pos, stack.result) : xpath_string(); + } + + case ast_func_substring_after: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, swapped_stack); + xpath_string p = _right->eval_string(c, swapped_stack); + + const char_t* pos = find_substring(s.c_str(), p.c_str()); + if (!pos) return xpath_string(); + + const char_t* result = pos + p.length(); + + return s.uses_heap() ? xpath_string(result, stack.result) : xpath_string_const(result); + } + + case ast_func_substring_2: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, swapped_stack); + size_t s_length = s.length(); + + double first = round_nearest(_right->eval_number(c, stack)); + + if (is_nan(first)) return xpath_string(); // NaN + else if (first >= s_length + 1) return xpath_string(); + + size_t pos = first < 1 ? 1 : static_cast(first); + assert(1 <= pos && pos <= s_length + 1); + + const char_t* rbegin = s.c_str() + (pos - 1); + + return s.uses_heap() ? xpath_string(rbegin, stack.result) : xpath_string_const(rbegin); + } + + case ast_func_substring_3: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, swapped_stack); + size_t s_length = s.length(); + + double first = round_nearest(_right->eval_number(c, stack)); + double last = first + round_nearest(_right->_next->eval_number(c, stack)); + + if (is_nan(first) || is_nan(last)) return xpath_string(); + else if (first >= s_length + 1) return xpath_string(); + else if (first >= last) return xpath_string(); + else if (last < 1) return xpath_string(); + + size_t pos = first < 1 ? 1 : static_cast(first); + size_t end = last >= s_length + 1 ? 
s_length + 1 : static_cast(last); + + assert(1 <= pos && pos <= end && end <= s_length + 1); + const char_t* rbegin = s.c_str() + (pos - 1); + const char_t* rend = s.c_str() + (end - 1); + + return (end == s_length + 1 && !s.uses_heap()) ? xpath_string_const(rbegin) : xpath_string(rbegin, rend, stack.result); + } + + case ast_func_normalize_space_0: + { + xpath_string s = string_value(c.n, stack.result); + + normalize_space(s.data(stack.result)); + + return s; + } + + case ast_func_normalize_space_1: + { + xpath_string s = _left->eval_string(c, stack); + + normalize_space(s.data(stack.result)); + + return s; + } + + case ast_func_translate: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, stack); + xpath_string from = _right->eval_string(c, swapped_stack); + xpath_string to = _right->_next->eval_string(c, swapped_stack); + + translate(s.data(stack.result), from.c_str(), to.c_str()); + + return s; + } + + case ast_variable: + { + assert(_rettype == _data.variable->type()); + + if (_rettype == xpath_type_string) + return xpath_string_const(_data.variable->get_string()); + + // fallthrough to type conversion + } + + default: + { + switch (_rettype) + { + case xpath_type_boolean: + return xpath_string_const(eval_boolean(c, stack) ? PUGIXML_TEXT("true") : PUGIXML_TEXT("false")); + + case xpath_type_number: + return convert_number_to_string(eval_number(c, stack), stack.result); + + case xpath_type_node_set: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_node_set_raw ns = eval_node_set(c, swapped_stack); + return ns.empty() ? 
xpath_string() : string_value(ns.first(), stack.result); + } + + default: + assert(!"Wrong expression for return type string"); + return xpath_string(); + } + } + } + } + + xpath_node_set_raw eval_node_set(const xpath_context& c, const xpath_stack& stack) + { + switch (_type) + { + case ast_op_union: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_node_set_raw ls = _left->eval_node_set(c, swapped_stack); + xpath_node_set_raw rs = _right->eval_node_set(c, stack); + + // we can optimize merging two sorted sets, but this is a very rare operation, so don't bother + rs.set_type(xpath_node_set::type_unsorted); + + rs.append(ls.begin(), ls.end(), stack.result); + rs.remove_duplicates(); + + return rs; + } + + case ast_filter: + case ast_filter_posinv: + { + xpath_node_set_raw set = _left->eval_node_set(c, stack); + + // either expression is a number or it contains position() call; sort by document order + if (_type == ast_filter) set.sort_do(); + + apply_predicate(set, 0, _right, stack); + + return set; + } + + case ast_func_id: + return xpath_node_set_raw(); + + case ast_step: + { + switch (_axis) + { + case axis_ancestor: + return step_do(c, stack, axis_to_type()); + + case axis_ancestor_or_self: + return step_do(c, stack, axis_to_type()); + + case axis_attribute: + return step_do(c, stack, axis_to_type()); + + case axis_child: + return step_do(c, stack, axis_to_type()); + + case axis_descendant: + return step_do(c, stack, axis_to_type()); + + case axis_descendant_or_self: + return step_do(c, stack, axis_to_type()); + + case axis_following: + return step_do(c, stack, axis_to_type()); + + case axis_following_sibling: + return step_do(c, stack, axis_to_type()); + + case axis_namespace: + // namespaced axis is not supported + return xpath_node_set_raw(); + + case axis_parent: + return step_do(c, stack, axis_to_type()); + + case axis_preceding: + return step_do(c, stack, axis_to_type()); + + case 
axis_preceding_sibling: + return step_do(c, stack, axis_to_type()); + + case axis_self: + return step_do(c, stack, axis_to_type()); + + default: + assert(!"Unknown axis"); + return xpath_node_set_raw(); + } + } + + case ast_step_root: + { + assert(!_right); // root step can't have any predicates + + xpath_node_set_raw ns; + + ns.set_type(xpath_node_set::type_sorted); + + if (c.n.node()) ns.push_back(c.n.node().root(), stack.result); + else if (c.n.attribute()) ns.push_back(c.n.parent().root(), stack.result); + + return ns; + } + + case ast_variable: + { + assert(_rettype == _data.variable->type()); + + if (_rettype == xpath_type_node_set) + { + const xpath_node_set& s = _data.variable->get_node_set(); + + xpath_node_set_raw ns; + + ns.set_type(s.type()); + ns.append(s.begin(), s.end(), stack.result); + + return ns; + } + + // fallthrough to type conversion + } + + default: + assert(!"Wrong expression for return type node set"); + return xpath_node_set_raw(); + } + } + + bool is_posinv() + { + switch (_type) + { + case ast_func_position: + return false; + + case ast_string_constant: + case ast_number_constant: + case ast_variable: + return true; + + case ast_step: + case ast_step_root: + return true; + + case ast_predicate: + case ast_filter: + case ast_filter_posinv: + return true; + + default: + if (_left && !_left->is_posinv()) return false; + + for (xpath_ast_node* n = _right; n; n = n->_next) + if (!n->is_posinv()) return false; + + return true; + } + } + + xpath_value_type rettype() const + { + return static_cast(_rettype); + } + }; + + struct xpath_parser + { + xpath_allocator* _alloc; + xpath_lexer _lexer; + + const char_t* _query; + xpath_variable_set* _variables; + + xpath_parse_result* _result; + + #ifdef PUGIXML_NO_EXCEPTIONS + jmp_buf _error_handler; + #endif + + void throw_error(const char* message) + { + _result->error = message; + _result->offset = _lexer.current_pos() - _query; + + #ifdef PUGIXML_NO_EXCEPTIONS + longjmp(_error_handler, 1); + #else + 
throw xpath_exception(*_result); + #endif + } + + void throw_error_oom() + { + #ifdef PUGIXML_NO_EXCEPTIONS + throw_error("Out of memory"); + #else + throw std::bad_alloc(); + #endif + } + + void* alloc_node() + { + void* result = _alloc->allocate_nothrow(sizeof(xpath_ast_node)); + + if (!result) throw_error_oom(); + + return result; + } + + const char_t* alloc_string(const xpath_lexer_string& value) + { + if (value.begin) + { + size_t length = static_cast(value.end - value.begin); + + char_t* c = static_cast(_alloc->allocate_nothrow((length + 1) * sizeof(char_t))); + if (!c) throw_error_oom(); + + memcpy(c, value.begin, length * sizeof(char_t)); + c[length] = 0; + + return c; + } + else return 0; + } + + xpath_ast_node* parse_function_helper(ast_type_t type0, ast_type_t type1, size_t argc, xpath_ast_node* args[2]) + { + assert(argc <= 1); + + if (argc == 1 && args[0]->rettype() != xpath_type_node_set) throw_error("Function has to be applied to node set"); + + return new (alloc_node()) xpath_ast_node(argc == 0 ? 
type0 : type1, xpath_type_string, args[0]); + } + + xpath_ast_node* parse_function(const xpath_lexer_string& name, size_t argc, xpath_ast_node* args[2]) + { + switch (name.begin[0]) + { + case 'b': + if (name == PUGIXML_TEXT("boolean") && argc == 1) + return new (alloc_node()) xpath_ast_node(ast_func_boolean, xpath_type_boolean, args[0]); + + break; + + case 'c': + if (name == PUGIXML_TEXT("count") && argc == 1) + { + if (args[0]->rettype() != xpath_type_node_set) throw_error("Function has to be applied to node set"); + return new (alloc_node()) xpath_ast_node(ast_func_count, xpath_type_number, args[0]); + } + else if (name == PUGIXML_TEXT("contains") && argc == 2) + return new (alloc_node()) xpath_ast_node(ast_func_contains, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("concat") && argc >= 2) + return new (alloc_node()) xpath_ast_node(ast_func_concat, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("ceiling") && argc == 1) + return new (alloc_node()) xpath_ast_node(ast_func_ceiling, xpath_type_number, args[0]); + + break; + + case 'f': + if (name == PUGIXML_TEXT("false") && argc == 0) + return new (alloc_node()) xpath_ast_node(ast_func_false, xpath_type_boolean); + else if (name == PUGIXML_TEXT("floor") && argc == 1) + return new (alloc_node()) xpath_ast_node(ast_func_floor, xpath_type_number, args[0]); + + break; + + case 'i': + if (name == PUGIXML_TEXT("id") && argc == 1) + return new (alloc_node()) xpath_ast_node(ast_func_id, xpath_type_node_set, args[0]); + + break; + + case 'l': + if (name == PUGIXML_TEXT("last") && argc == 0) + return new (alloc_node()) xpath_ast_node(ast_func_last, xpath_type_number); + else if (name == PUGIXML_TEXT("lang") && argc == 1) + return new (alloc_node()) xpath_ast_node(ast_func_lang, xpath_type_boolean, args[0]); + else if (name == PUGIXML_TEXT("local-name") && argc <= 1) + return parse_function_helper(ast_func_local_name_0, ast_func_local_name_1, argc, args); + + break; + + case 
'n': + if (name == PUGIXML_TEXT("name") && argc <= 1) + return parse_function_helper(ast_func_name_0, ast_func_name_1, argc, args); + else if (name == PUGIXML_TEXT("namespace-uri") && argc <= 1) + return parse_function_helper(ast_func_namespace_uri_0, ast_func_namespace_uri_1, argc, args); + else if (name == PUGIXML_TEXT("normalize-space") && argc <= 1) + return new (alloc_node()) xpath_ast_node(argc == 0 ? ast_func_normalize_space_0 : ast_func_normalize_space_1, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("not") && argc == 1) + return new (alloc_node()) xpath_ast_node(ast_func_not, xpath_type_boolean, args[0]); + else if (name == PUGIXML_TEXT("number") && argc <= 1) + return new (alloc_node()) xpath_ast_node(argc == 0 ? ast_func_number_0 : ast_func_number_1, xpath_type_number, args[0]); + + break; + + case 'p': + if (name == PUGIXML_TEXT("position") && argc == 0) + return new (alloc_node()) xpath_ast_node(ast_func_position, xpath_type_number); + + break; + + case 'r': + if (name == PUGIXML_TEXT("round") && argc == 1) + return new (alloc_node()) xpath_ast_node(ast_func_round, xpath_type_number, args[0]); + + break; + + case 's': + if (name == PUGIXML_TEXT("string") && argc <= 1) + return new (alloc_node()) xpath_ast_node(argc == 0 ? ast_func_string_0 : ast_func_string_1, xpath_type_string, args[0]); + else if (name == PUGIXML_TEXT("string-length") && argc <= 1) + return new (alloc_node()) xpath_ast_node(argc == 0 ? 
ast_func_string_length_0 : ast_func_string_length_1, xpath_type_string, args[0]); + else if (name == PUGIXML_TEXT("starts-with") && argc == 2) + return new (alloc_node()) xpath_ast_node(ast_func_starts_with, xpath_type_boolean, args[0], args[1]); + else if (name == PUGIXML_TEXT("substring-before") && argc == 2) + return new (alloc_node()) xpath_ast_node(ast_func_substring_before, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("substring-after") && argc == 2) + return new (alloc_node()) xpath_ast_node(ast_func_substring_after, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("substring") && (argc == 2 || argc == 3)) + return new (alloc_node()) xpath_ast_node(argc == 2 ? ast_func_substring_2 : ast_func_substring_3, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("sum") && argc == 1) + { + if (args[0]->rettype() != xpath_type_node_set) throw_error("Function has to be applied to node set"); + return new (alloc_node()) xpath_ast_node(ast_func_sum, xpath_type_number, args[0]); + } + + break; + + case 't': + if (name == PUGIXML_TEXT("translate") && argc == 3) + return new (alloc_node()) xpath_ast_node(ast_func_translate, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("true") && argc == 0) + return new (alloc_node()) xpath_ast_node(ast_func_true, xpath_type_boolean); + + break; + + default: + break; + } + + throw_error("Unrecognized function or wrong parameter count"); + + return 0; + } + + axis_t parse_axis_name(const xpath_lexer_string& name, bool& specified) + { + specified = true; + + switch (name.begin[0]) + { + case 'a': + if (name == PUGIXML_TEXT("ancestor")) + return axis_ancestor; + else if (name == PUGIXML_TEXT("ancestor-or-self")) + return axis_ancestor_or_self; + else if (name == PUGIXML_TEXT("attribute")) + return axis_attribute; + + break; + + case 'c': + if (name == PUGIXML_TEXT("child")) + return axis_child; + + break; + + case 'd': + if (name == 
PUGIXML_TEXT("descendant")) + return axis_descendant; + else if (name == PUGIXML_TEXT("descendant-or-self")) + return axis_descendant_or_self; + + break; + + case 'f': + if (name == PUGIXML_TEXT("following")) + return axis_following; + else if (name == PUGIXML_TEXT("following-sibling")) + return axis_following_sibling; + + break; + + case 'n': + if (name == PUGIXML_TEXT("namespace")) + return axis_namespace; + + break; + + case 'p': + if (name == PUGIXML_TEXT("parent")) + return axis_parent; + else if (name == PUGIXML_TEXT("preceding")) + return axis_preceding; + else if (name == PUGIXML_TEXT("preceding-sibling")) + return axis_preceding_sibling; + + break; + + case 's': + if (name == PUGIXML_TEXT("self")) + return axis_self; + + break; + + default: + break; + } + + specified = false; + return axis_child; + } + + nodetest_t parse_node_test_type(const xpath_lexer_string& name) + { + switch (name.begin[0]) + { + case 'c': + if (name == PUGIXML_TEXT("comment")) + return nodetest_type_comment; + + break; + + case 'n': + if (name == PUGIXML_TEXT("node")) + return nodetest_type_node; + + break; + + case 'p': + if (name == PUGIXML_TEXT("processing-instruction")) + return nodetest_type_pi; + + break; + + case 't': + if (name == PUGIXML_TEXT("text")) + return nodetest_type_text; + + break; + + default: + break; + } + + return nodetest_none; + } + + // PrimaryExpr ::= VariableReference | '(' Expr ')' | Literal | Number | FunctionCall + xpath_ast_node* parse_primary_expression() + { + switch (_lexer.current()) + { + case lex_var_ref: + { + xpath_lexer_string name = _lexer.contents(); + + if (!_variables) + throw_error("Unknown variable: variable set is not provided"); + + xpath_variable* var = get_variable(_variables, name.begin, name.end); + + if (!var) + throw_error("Unknown variable: variable set does not contain the given name"); + + _lexer.next(); + + return new (alloc_node()) xpath_ast_node(ast_variable, var->type(), var); + } + + case lex_open_brace: + { + 
_lexer.next(); + + xpath_ast_node* n = parse_expression(); + + if (_lexer.current() != lex_close_brace) + throw_error("Unmatched braces"); + + _lexer.next(); + + return n; + } + + case lex_quoted_string: + { + const char_t* value = alloc_string(_lexer.contents()); + + xpath_ast_node* n = new (alloc_node()) xpath_ast_node(ast_string_constant, xpath_type_string, value); + _lexer.next(); + + return n; + } + + case lex_number: + { + double value = 0; + + if (!convert_string_to_number(_lexer.contents().begin, _lexer.contents().end, &value)) + throw_error_oom(); + + xpath_ast_node* n = new (alloc_node()) xpath_ast_node(ast_number_constant, xpath_type_number, value); + _lexer.next(); + + return n; + } + + case lex_string: + { + xpath_ast_node* args[2] = {0}; + size_t argc = 0; + + xpath_lexer_string function = _lexer.contents(); + _lexer.next(); + + xpath_ast_node* last_arg = 0; + + if (_lexer.current() != lex_open_brace) + throw_error("Unrecognized function call"); + _lexer.next(); + + if (_lexer.current() != lex_close_brace) + args[argc++] = parse_expression(); + + while (_lexer.current() != lex_close_brace) + { + if (_lexer.current() != lex_comma) + throw_error("No comma between function arguments"); + _lexer.next(); + + xpath_ast_node* n = parse_expression(); + + if (argc < 2) args[argc] = n; + else last_arg->set_next(n); + + argc++; + last_arg = n; + } + + _lexer.next(); + + return parse_function(function, argc, args); + } + + default: + throw_error("Unrecognizable primary expression"); + + return 0; + } + } + + // FilterExpr ::= PrimaryExpr | FilterExpr Predicate + // Predicate ::= '[' PredicateExpr ']' + // PredicateExpr ::= Expr + xpath_ast_node* parse_filter_expression() + { + xpath_ast_node* n = parse_primary_expression(); + + while (_lexer.current() == lex_open_square_brace) + { + _lexer.next(); + + xpath_ast_node* expr = parse_expression(); + + if (n->rettype() != xpath_type_node_set) throw_error("Predicate has to be applied to node set"); + + bool posinv = 
expr->rettype() != xpath_type_number && expr->is_posinv(); + + n = new (alloc_node()) xpath_ast_node(posinv ? ast_filter_posinv : ast_filter, xpath_type_node_set, n, expr); + + if (_lexer.current() != lex_close_square_brace) + throw_error("Unmatched square brace"); + + _lexer.next(); + } + + return n; + } + + // Step ::= AxisSpecifier NodeTest Predicate* | AbbreviatedStep + // AxisSpecifier ::= AxisName '::' | '@'? + // NodeTest ::= NameTest | NodeType '(' ')' | 'processing-instruction' '(' Literal ')' + // NameTest ::= '*' | NCName ':' '*' | QName + // AbbreviatedStep ::= '.' | '..' + xpath_ast_node* parse_step(xpath_ast_node* set) + { + if (set && set->rettype() != xpath_type_node_set) + throw_error("Step has to be applied to node set"); + + bool axis_specified = false; + axis_t axis = axis_child; // implied child axis + + if (_lexer.current() == lex_axis_attribute) + { + axis = axis_attribute; + axis_specified = true; + + _lexer.next(); + } + else if (_lexer.current() == lex_dot) + { + _lexer.next(); + + return new (alloc_node()) xpath_ast_node(ast_step, set, axis_self, nodetest_type_node, 0); + } + else if (_lexer.current() == lex_double_dot) + { + _lexer.next(); + + return new (alloc_node()) xpath_ast_node(ast_step, set, axis_parent, nodetest_type_node, 0); + } + + nodetest_t nt_type = nodetest_none; + xpath_lexer_string nt_name; + + if (_lexer.current() == lex_string) + { + // node name test + nt_name = _lexer.contents(); + _lexer.next(); + + // was it an axis name? 
+ if (_lexer.current() == lex_double_colon) + { + // parse axis name + if (axis_specified) throw_error("Two axis specifiers in one step"); + + axis = parse_axis_name(nt_name, axis_specified); + + if (!axis_specified) throw_error("Unknown axis"); + + // read actual node test + _lexer.next(); + + if (_lexer.current() == lex_multiply) + { + nt_type = nodetest_all; + nt_name = xpath_lexer_string(); + _lexer.next(); + } + else if (_lexer.current() == lex_string) + { + nt_name = _lexer.contents(); + _lexer.next(); + } + else throw_error("Unrecognized node test"); + } + + if (nt_type == nodetest_none) + { + // node type test or processing-instruction + if (_lexer.current() == lex_open_brace) + { + _lexer.next(); + + if (_lexer.current() == lex_close_brace) + { + _lexer.next(); + + nt_type = parse_node_test_type(nt_name); + + if (nt_type == nodetest_none) throw_error("Unrecognized node type"); + + nt_name = xpath_lexer_string(); + } + else if (nt_name == PUGIXML_TEXT("processing-instruction")) + { + if (_lexer.current() != lex_quoted_string) + throw_error("Only literals are allowed as arguments to processing-instruction()"); + + nt_type = nodetest_pi; + nt_name = _lexer.contents(); + _lexer.next(); + + if (_lexer.current() != lex_close_brace) + throw_error("Unmatched brace near processing-instruction()"); + _lexer.next(); + } + else + throw_error("Unmatched brace near node type test"); + + } + // QName or NCName:* + else + { + if (nt_name.end - nt_name.begin > 2 && nt_name.end[-2] == ':' && nt_name.end[-1] == '*') // NCName:* + { + nt_name.end--; // erase * + + nt_type = nodetest_all_in_namespace; + } + else nt_type = nodetest_name; + } + } + } + else if (_lexer.current() == lex_multiply) + { + nt_type = nodetest_all; + _lexer.next(); + } + else throw_error("Unrecognized node test"); + + xpath_ast_node* n = new (alloc_node()) xpath_ast_node(ast_step, set, axis, nt_type, alloc_string(nt_name)); + + xpath_ast_node* last = 0; + + while (_lexer.current() == 
lex_open_square_brace) + { + _lexer.next(); + + xpath_ast_node* expr = parse_expression(); + + xpath_ast_node* pred = new (alloc_node()) xpath_ast_node(ast_predicate, xpath_type_node_set, expr); + + if (_lexer.current() != lex_close_square_brace) + throw_error("Unmatched square brace"); + _lexer.next(); + + if (last) last->set_next(pred); + else n->set_right(pred); + + last = pred; + } + + return n; + } + + // RelativeLocationPath ::= Step | RelativeLocationPath '/' Step | RelativeLocationPath '//' Step + xpath_ast_node* parse_relative_location_path(xpath_ast_node* set) + { + xpath_ast_node* n = parse_step(set); + + while (_lexer.current() == lex_slash || _lexer.current() == lex_double_slash) + { + lexeme_t l = _lexer.current(); + _lexer.next(); + + if (l == lex_double_slash) + n = new (alloc_node()) xpath_ast_node(ast_step, n, axis_descendant_or_self, nodetest_type_node, 0); + + n = parse_step(n); + } + + return n; + } + + // LocationPath ::= RelativeLocationPath | AbsoluteLocationPath + // AbsoluteLocationPath ::= '/' RelativeLocationPath? 
| '//' RelativeLocationPath + xpath_ast_node* parse_location_path() + { + if (_lexer.current() == lex_slash) + { + _lexer.next(); + + xpath_ast_node* n = new (alloc_node()) xpath_ast_node(ast_step_root, xpath_type_node_set); + + // relative location path can start from axis_attribute, dot, double_dot, multiply and string lexemes; any other lexeme means standalone root path + lexeme_t l = _lexer.current(); + + if (l == lex_string || l == lex_axis_attribute || l == lex_dot || l == lex_double_dot || l == lex_multiply) + return parse_relative_location_path(n); + else + return n; + } + else if (_lexer.current() == lex_double_slash) + { + _lexer.next(); + + xpath_ast_node* n = new (alloc_node()) xpath_ast_node(ast_step_root, xpath_type_node_set); + n = new (alloc_node()) xpath_ast_node(ast_step, n, axis_descendant_or_self, nodetest_type_node, 0); + + return parse_relative_location_path(n); + } + + // else clause moved outside of if because of bogus warning 'control may reach end of non-void function being inlined' in gcc 4.0.1 + return parse_relative_location_path(0); + } + + // PathExpr ::= LocationPath + // | FilterExpr + // | FilterExpr '/' RelativeLocationPath + // | FilterExpr '//' RelativeLocationPath + xpath_ast_node* parse_path_expression() + { + // Clarification. + // PathExpr begins with either LocationPath or FilterExpr. + // FilterExpr begins with PrimaryExpr + // PrimaryExpr begins with '$' in case of it being a variable reference, + // '(' in case of it being an expression, string literal, number constant or + // function call. 
+ + if (_lexer.current() == lex_var_ref || _lexer.current() == lex_open_brace || + _lexer.current() == lex_quoted_string || _lexer.current() == lex_number || + _lexer.current() == lex_string) + { + if (_lexer.current() == lex_string) + { + // This is either a function call, or not - if not, we shall proceed with location path + const char_t* state = _lexer.state(); + + while (PUGI__IS_CHARTYPE(*state, ct_space)) ++state; + + if (*state != '(') return parse_location_path(); + + // This looks like a function call; however this still can be a node-test. Check it. + if (parse_node_test_type(_lexer.contents()) != nodetest_none) return parse_location_path(); + } + + xpath_ast_node* n = parse_filter_expression(); + + if (_lexer.current() == lex_slash || _lexer.current() == lex_double_slash) + { + lexeme_t l = _lexer.current(); + _lexer.next(); + + if (l == lex_double_slash) + { + if (n->rettype() != xpath_type_node_set) throw_error("Step has to be applied to node set"); + + n = new (alloc_node()) xpath_ast_node(ast_step, n, axis_descendant_or_self, nodetest_type_node, 0); + } + + // select from location path + return parse_relative_location_path(n); + } + + return n; + } + else return parse_location_path(); + } + + // UnionExpr ::= PathExpr | UnionExpr '|' PathExpr + xpath_ast_node* parse_union_expression() + { + xpath_ast_node* n = parse_path_expression(); + + while (_lexer.current() == lex_union) + { + _lexer.next(); + + xpath_ast_node* expr = parse_union_expression(); + + if (n->rettype() != xpath_type_node_set || expr->rettype() != xpath_type_node_set) + throw_error("Union operator has to be applied to node sets"); + + n = new (alloc_node()) xpath_ast_node(ast_op_union, xpath_type_node_set, n, expr); + } + + return n; + } + + // UnaryExpr ::= UnionExpr | '-' UnaryExpr + xpath_ast_node* parse_unary_expression() + { + if (_lexer.current() == lex_minus) + { + _lexer.next(); + + xpath_ast_node* expr = parse_unary_expression(); + + return new (alloc_node()) 
xpath_ast_node(ast_op_negate, xpath_type_number, expr); + } + else return parse_union_expression(); + } + + // MultiplicativeExpr ::= UnaryExpr + // | MultiplicativeExpr '*' UnaryExpr + // | MultiplicativeExpr 'div' UnaryExpr + // | MultiplicativeExpr 'mod' UnaryExpr + xpath_ast_node* parse_multiplicative_expression() + { + xpath_ast_node* n = parse_unary_expression(); + + while (_lexer.current() == lex_multiply || (_lexer.current() == lex_string && + (_lexer.contents() == PUGIXML_TEXT("mod") || _lexer.contents() == PUGIXML_TEXT("div")))) + { + ast_type_t op = _lexer.current() == lex_multiply ? ast_op_multiply : + _lexer.contents().begin[0] == 'd' ? ast_op_divide : ast_op_mod; + _lexer.next(); + + xpath_ast_node* expr = parse_unary_expression(); + + n = new (alloc_node()) xpath_ast_node(op, xpath_type_number, n, expr); + } + + return n; + } + + // AdditiveExpr ::= MultiplicativeExpr + // | AdditiveExpr '+' MultiplicativeExpr + // | AdditiveExpr '-' MultiplicativeExpr + xpath_ast_node* parse_additive_expression() + { + xpath_ast_node* n = parse_multiplicative_expression(); + + while (_lexer.current() == lex_plus || _lexer.current() == lex_minus) + { + lexeme_t l = _lexer.current(); + + _lexer.next(); + + xpath_ast_node* expr = parse_multiplicative_expression(); + + n = new (alloc_node()) xpath_ast_node(l == lex_plus ? 
ast_op_add : ast_op_subtract, xpath_type_number, n, expr); + } + + return n; + } + + // RelationalExpr ::= AdditiveExpr + // | RelationalExpr '<' AdditiveExpr + // | RelationalExpr '>' AdditiveExpr + // | RelationalExpr '<=' AdditiveExpr + // | RelationalExpr '>=' AdditiveExpr + xpath_ast_node* parse_relational_expression() + { + xpath_ast_node* n = parse_additive_expression(); + + while (_lexer.current() == lex_less || _lexer.current() == lex_less_or_equal || + _lexer.current() == lex_greater || _lexer.current() == lex_greater_or_equal) + { + lexeme_t l = _lexer.current(); + _lexer.next(); + + xpath_ast_node* expr = parse_additive_expression(); + + n = new (alloc_node()) xpath_ast_node(l == lex_less ? ast_op_less : l == lex_greater ? ast_op_greater : + l == lex_less_or_equal ? ast_op_less_or_equal : ast_op_greater_or_equal, xpath_type_boolean, n, expr); + } + + return n; + } + + // EqualityExpr ::= RelationalExpr + // | EqualityExpr '=' RelationalExpr + // | EqualityExpr '!=' RelationalExpr + xpath_ast_node* parse_equality_expression() + { + xpath_ast_node* n = parse_relational_expression(); + + while (_lexer.current() == lex_equal || _lexer.current() == lex_not_equal) + { + lexeme_t l = _lexer.current(); + + _lexer.next(); + + xpath_ast_node* expr = parse_relational_expression(); + + n = new (alloc_node()) xpath_ast_node(l == lex_equal ? 
ast_op_equal : ast_op_not_equal, xpath_type_boolean, n, expr); + } + + return n; + } + + // AndExpr ::= EqualityExpr | AndExpr 'and' EqualityExpr + xpath_ast_node* parse_and_expression() + { + xpath_ast_node* n = parse_equality_expression(); + + while (_lexer.current() == lex_string && _lexer.contents() == PUGIXML_TEXT("and")) + { + _lexer.next(); + + xpath_ast_node* expr = parse_equality_expression(); + + n = new (alloc_node()) xpath_ast_node(ast_op_and, xpath_type_boolean, n, expr); + } + + return n; + } + + // OrExpr ::= AndExpr | OrExpr 'or' AndExpr + xpath_ast_node* parse_or_expression() + { + xpath_ast_node* n = parse_and_expression(); + + while (_lexer.current() == lex_string && _lexer.contents() == PUGIXML_TEXT("or")) + { + _lexer.next(); + + xpath_ast_node* expr = parse_and_expression(); + + n = new (alloc_node()) xpath_ast_node(ast_op_or, xpath_type_boolean, n, expr); + } + + return n; + } + + // Expr ::= OrExpr + xpath_ast_node* parse_expression() + { + return parse_or_expression(); + } + + xpath_parser(const char_t* query, xpath_variable_set* variables, xpath_allocator* alloc, xpath_parse_result* result): _alloc(alloc), _lexer(query), _query(query), _variables(variables), _result(result) + { + } + + xpath_ast_node* parse() + { + xpath_ast_node* result = parse_expression(); + + if (_lexer.current() != lex_eof) + { + // there are still unparsed tokens left, error + throw_error("Incorrect query"); + } + + return result; + } + + static xpath_ast_node* parse(const char_t* query, xpath_variable_set* variables, xpath_allocator* alloc, xpath_parse_result* result) + { + xpath_parser parser(query, variables, alloc, result); + + #ifdef PUGIXML_NO_EXCEPTIONS + int error = setjmp(parser._error_handler); + + return (error == 0) ? 
parser.parse() : 0; + #else + return parser.parse(); + #endif + } + }; + + struct xpath_query_impl + { + static xpath_query_impl* create() + { + void* memory = xml_memory::allocate(sizeof(xpath_query_impl)); + + return new (memory) xpath_query_impl(); + } + + static void destroy(void* ptr) + { + if (!ptr) return; + + // free all allocated pages + static_cast(ptr)->alloc.release(); + + // free allocator memory (with the first page) + xml_memory::deallocate(ptr); + } + + xpath_query_impl(): root(0), alloc(&block) + { + block.next = 0; + } + + xpath_ast_node* root; + xpath_allocator alloc; + xpath_memory_block block; + }; + + PUGI__FN xpath_string evaluate_string_impl(xpath_query_impl* impl, const xpath_node& n, xpath_stack_data& sd) + { + if (!impl) return xpath_string(); + + #ifdef PUGIXML_NO_EXCEPTIONS + if (setjmp(sd.error_handler)) return xpath_string(); + #endif + + xpath_context c(n, 1, 1); + + return impl->root->eval_string(c, sd.stack); + } +PUGI__NS_END + +namespace pugi +{ +#ifndef PUGIXML_NO_EXCEPTIONS + PUGI__FN xpath_exception::xpath_exception(const xpath_parse_result& result_): _result(result_) + { + assert(_result.error); + } + + PUGI__FN const char* xpath_exception::what() const throw() + { + return _result.error; + } + + PUGI__FN const xpath_parse_result& xpath_exception::result() const + { + return _result; + } +#endif + + PUGI__FN xpath_node::xpath_node() + { + } + + PUGI__FN xpath_node::xpath_node(const xml_node& node_): _node(node_) + { + } + + PUGI__FN xpath_node::xpath_node(const xml_attribute& attribute_, const xml_node& parent_): _node(attribute_ ? parent_ : xml_node()), _attribute(attribute_) + { + } + + PUGI__FN xml_node xpath_node::node() const + { + return _attribute ? xml_node() : _node; + } + + PUGI__FN xml_attribute xpath_node::attribute() const + { + return _attribute; + } + + PUGI__FN xml_node xpath_node::parent() const + { + return _attribute ? 
_node : _node.parent(); + } + + PUGI__FN static void unspecified_bool_xpath_node(xpath_node***) + { + } + + PUGI__FN xpath_node::operator xpath_node::unspecified_bool_type() const + { + return (_node || _attribute) ? unspecified_bool_xpath_node : 0; + } + + PUGI__FN bool xpath_node::operator!() const + { + return !(_node || _attribute); + } + + PUGI__FN bool xpath_node::operator==(const xpath_node& n) const + { + return _node == n._node && _attribute == n._attribute; + } + + PUGI__FN bool xpath_node::operator!=(const xpath_node& n) const + { + return _node != n._node || _attribute != n._attribute; + } + +#ifdef __BORLANDC__ + PUGI__FN bool operator&&(const xpath_node& lhs, bool rhs) + { + return (bool)lhs && rhs; + } + + PUGI__FN bool operator||(const xpath_node& lhs, bool rhs) + { + return (bool)lhs || rhs; + } +#endif + + PUGI__FN void xpath_node_set::_assign(const_iterator begin_, const_iterator end_) + { + assert(begin_ <= end_); + + size_t size_ = static_cast(end_ - begin_); + + if (size_ <= 1) + { + // deallocate old buffer + if (_begin != &_storage) impl::xml_memory::deallocate(_begin); + + // use internal buffer + if (begin_ != end_) _storage = *begin_; + + _begin = &_storage; + _end = &_storage + size_; + } + else + { + // make heap copy + xpath_node* storage = static_cast(impl::xml_memory::allocate(size_ * sizeof(xpath_node))); + + if (!storage) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return; + #else + throw std::bad_alloc(); + #endif + } + + memcpy(storage, begin_, size_ * sizeof(xpath_node)); + + // deallocate old buffer + if (_begin != &_storage) impl::xml_memory::deallocate(_begin); + + // finalize + _begin = storage; + _end = storage + size_; + } + } + + PUGI__FN xpath_node_set::xpath_node_set(): _type(type_unsorted), _begin(&_storage), _end(&_storage) + { + } + + PUGI__FN xpath_node_set::xpath_node_set(const_iterator begin_, const_iterator end_, type_t type_): _type(type_), _begin(&_storage), _end(&_storage) + { + _assign(begin_, end_); + } + + 
PUGI__FN xpath_node_set::~xpath_node_set() + { + if (_begin != &_storage) impl::xml_memory::deallocate(_begin); + } + + PUGI__FN xpath_node_set::xpath_node_set(const xpath_node_set& ns): _type(ns._type), _begin(&_storage), _end(&_storage) + { + _assign(ns._begin, ns._end); + } + + PUGI__FN xpath_node_set& xpath_node_set::operator=(const xpath_node_set& ns) + { + if (this == &ns) return *this; + + _type = ns._type; + _assign(ns._begin, ns._end); + + return *this; + } + + PUGI__FN xpath_node_set::type_t xpath_node_set::type() const + { + return _type; + } + + PUGI__FN size_t xpath_node_set::size() const + { + return _end - _begin; + } + + PUGI__FN bool xpath_node_set::empty() const + { + return _begin == _end; + } + + PUGI__FN const xpath_node& xpath_node_set::operator[](size_t index) const + { + assert(index < size()); + return _begin[index]; + } + + PUGI__FN xpath_node_set::const_iterator xpath_node_set::begin() const + { + return _begin; + } + + PUGI__FN xpath_node_set::const_iterator xpath_node_set::end() const + { + return _end; + } + + PUGI__FN void xpath_node_set::sort(bool reverse) + { + _type = impl::xpath_sort(_begin, _end, _type, reverse); + } + + PUGI__FN xpath_node xpath_node_set::first() const + { + return impl::xpath_first(_begin, _end, _type); + } + + PUGI__FN xpath_parse_result::xpath_parse_result(): error("Internal error"), offset(0) + { + } + + PUGI__FN xpath_parse_result::operator bool() const + { + return error == 0; + } + + PUGI__FN const char* xpath_parse_result::description() const + { + return error ? 
error : "No error"; + } + + PUGI__FN xpath_variable::xpath_variable() + { + } + + PUGI__FN const char_t* xpath_variable::name() const + { + switch (_type) + { + case xpath_type_node_set: + return static_cast(this)->name; + + case xpath_type_number: + return static_cast(this)->name; + + case xpath_type_string: + return static_cast(this)->name; + + case xpath_type_boolean: + return static_cast(this)->name; + + default: + assert(!"Invalid variable type"); + return 0; + } + } + + PUGI__FN xpath_value_type xpath_variable::type() const + { + return _type; + } + + PUGI__FN bool xpath_variable::get_boolean() const + { + return (_type == xpath_type_boolean) ? static_cast(this)->value : false; + } + + PUGI__FN double xpath_variable::get_number() const + { + return (_type == xpath_type_number) ? static_cast(this)->value : impl::gen_nan(); + } + + PUGI__FN const char_t* xpath_variable::get_string() const + { + const char_t* value = (_type == xpath_type_string) ? static_cast(this)->value : 0; + return value ? value : PUGIXML_TEXT(""); + } + + PUGI__FN const xpath_node_set& xpath_variable::get_node_set() const + { + return (_type == xpath_type_node_set) ? 
static_cast(this)->value : impl::dummy_node_set; + } + + PUGI__FN bool xpath_variable::set(bool value) + { + if (_type != xpath_type_boolean) return false; + + static_cast(this)->value = value; + return true; + } + + PUGI__FN bool xpath_variable::set(double value) + { + if (_type != xpath_type_number) return false; + + static_cast(this)->value = value; + return true; + } + + PUGI__FN bool xpath_variable::set(const char_t* value) + { + if (_type != xpath_type_string) return false; + + impl::xpath_variable_string* var = static_cast(this); + + // duplicate string + size_t size = (impl::strlength(value) + 1) * sizeof(char_t); + + char_t* copy = static_cast(impl::xml_memory::allocate(size)); + if (!copy) return false; + + memcpy(copy, value, size); + + // replace old string + if (var->value) impl::xml_memory::deallocate(var->value); + var->value = copy; + + return true; + } + + PUGI__FN bool xpath_variable::set(const xpath_node_set& value) + { + if (_type != xpath_type_node_set) return false; + + static_cast(this)->value = value; + return true; + } + + PUGI__FN xpath_variable_set::xpath_variable_set() + { + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) _data[i] = 0; + } + + PUGI__FN xpath_variable_set::~xpath_variable_set() + { + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) + { + xpath_variable* var = _data[i]; + + while (var) + { + xpath_variable* next = var->_next; + + impl::delete_xpath_variable(var->_type, var); + + var = next; + } + } + } + + PUGI__FN xpath_variable* xpath_variable_set::find(const char_t* name) const + { + const size_t hash_size = sizeof(_data) / sizeof(_data[0]); + size_t hash = impl::hash_string(name) % hash_size; + + // look for existing variable + for (xpath_variable* var = _data[hash]; var; var = var->_next) + if (impl::strequal(var->name(), name)) + return var; + + return 0; + } + + PUGI__FN xpath_variable* xpath_variable_set::add(const char_t* name, xpath_value_type type) + { + const size_t hash_size = 
sizeof(_data) / sizeof(_data[0]); + size_t hash = impl::hash_string(name) % hash_size; + + // look for existing variable + for (xpath_variable* var = _data[hash]; var; var = var->_next) + if (impl::strequal(var->name(), name)) + return var->type() == type ? var : 0; + + // add new variable + xpath_variable* result = impl::new_xpath_variable(type, name); + + if (result) + { + result->_type = type; + result->_next = _data[hash]; + + _data[hash] = result; + } + + return result; + } + + PUGI__FN bool xpath_variable_set::set(const char_t* name, bool value) + { + xpath_variable* var = add(name, xpath_type_boolean); + return var ? var->set(value) : false; + } + + PUGI__FN bool xpath_variable_set::set(const char_t* name, double value) + { + xpath_variable* var = add(name, xpath_type_number); + return var ? var->set(value) : false; + } + + PUGI__FN bool xpath_variable_set::set(const char_t* name, const char_t* value) + { + xpath_variable* var = add(name, xpath_type_string); + return var ? var->set(value) : false; + } + + PUGI__FN bool xpath_variable_set::set(const char_t* name, const xpath_node_set& value) + { + xpath_variable* var = add(name, xpath_type_node_set); + return var ? 
var->set(value) : false; + } + + PUGI__FN xpath_variable* xpath_variable_set::get(const char_t* name) + { + return find(name); + } + + PUGI__FN const xpath_variable* xpath_variable_set::get(const char_t* name) const + { + return find(name); + } + + PUGI__FN xpath_query::xpath_query(const char_t* query, xpath_variable_set* variables): _impl(0) + { + impl::xpath_query_impl* qimpl = impl::xpath_query_impl::create(); + + if (!qimpl) + { + #ifdef PUGIXML_NO_EXCEPTIONS + _result.error = "Out of memory"; + #else + throw std::bad_alloc(); + #endif + } + else + { + impl::buffer_holder impl_holder(qimpl, impl::xpath_query_impl::destroy); + + qimpl->root = impl::xpath_parser::parse(query, variables, &qimpl->alloc, &_result); + + if (qimpl->root) + { + _impl = static_cast(impl_holder.release()); + _result.error = 0; + } + } + } + + PUGI__FN xpath_query::~xpath_query() + { + impl::xpath_query_impl::destroy(_impl); + } + + PUGI__FN xpath_value_type xpath_query::return_type() const + { + if (!_impl) return xpath_type_none; + + return static_cast(_impl)->root->rettype(); + } + + PUGI__FN bool xpath_query::evaluate_boolean(const xpath_node& n) const + { + if (!_impl) return false; + + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + #ifdef PUGIXML_NO_EXCEPTIONS + if (setjmp(sd.error_handler)) return false; + #endif + + return static_cast(_impl)->root->eval_boolean(c, sd.stack); + } + + PUGI__FN double xpath_query::evaluate_number(const xpath_node& n) const + { + if (!_impl) return impl::gen_nan(); + + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + #ifdef PUGIXML_NO_EXCEPTIONS + if (setjmp(sd.error_handler)) return impl::gen_nan(); + #endif + + return static_cast(_impl)->root->eval_number(c, sd.stack); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN string_t xpath_query::evaluate_string(const xpath_node& n) const + { + impl::xpath_stack_data sd; + + return impl::evaluate_string_impl(static_cast(_impl), n, sd).c_str(); + } +#endif + + PUGI__FN size_t 
xpath_query::evaluate_string(char_t* buffer, size_t capacity, const xpath_node& n) const + { + impl::xpath_stack_data sd; + + impl::xpath_string r = impl::evaluate_string_impl(static_cast(_impl), n, sd); + + size_t full_size = r.length() + 1; + + if (capacity > 0) + { + size_t size = (full_size < capacity) ? full_size : capacity; + assert(size > 0); + + memcpy(buffer, r.c_str(), (size - 1) * sizeof(char_t)); + buffer[size - 1] = 0; + } + + return full_size; + } + + PUGI__FN xpath_node_set xpath_query::evaluate_node_set(const xpath_node& n) const + { + if (!_impl) return xpath_node_set(); + + impl::xpath_ast_node* root = static_cast(_impl)->root; + + if (root->rettype() != xpath_type_node_set) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return xpath_node_set(); + #else + xpath_parse_result res; + res.error = "Expression does not evaluate to node set"; + + throw xpath_exception(res); + #endif + } + + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + #ifdef PUGIXML_NO_EXCEPTIONS + if (setjmp(sd.error_handler)) return xpath_node_set(); + #endif + + impl::xpath_node_set_raw r = root->eval_node_set(c, sd.stack); + + return xpath_node_set(r.begin(), r.end(), r.type()); + } + + PUGI__FN const xpath_parse_result& xpath_query::result() const + { + return _result; + } + + PUGI__FN static void unspecified_bool_xpath_query(xpath_query***) + { + } + + PUGI__FN xpath_query::operator xpath_query::unspecified_bool_type() const + { + return _impl ? unspecified_bool_xpath_query : 0; + } + + PUGI__FN bool xpath_query::operator!() const + { + return !_impl; + } + + PUGI__FN xpath_node xml_node::select_single_node(const char_t* query, xpath_variable_set* variables) const + { + xpath_query q(query, variables); + return select_single_node(q); + } + + PUGI__FN xpath_node xml_node::select_single_node(const xpath_query& query) const + { + xpath_node_set s = query.evaluate_node_set(*this); + return s.empty() ? 
xpath_node() : s.first(); + } + + PUGI__FN xpath_node_set xml_node::select_nodes(const char_t* query, xpath_variable_set* variables) const + { + xpath_query q(query, variables); + return select_nodes(q); + } + + PUGI__FN xpath_node_set xml_node::select_nodes(const xpath_query& query) const + { + return query.evaluate_node_set(*this); + } +} + +#endif + +#ifdef __BORLANDC__ +# pragma option pop +#endif + +// Intel C++ does not properly keep warning state for function templates, +// so popping warning state at the end of translation unit leads to warnings in the middle. +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +# pragma warning(pop) +#endif + +// Undefine all local macros (makes sure we're not leaking macros in header-only mode) +#undef PUGI__NO_INLINE +#undef PUGI__STATIC_ASSERT +#undef PUGI__DMC_VOLATILE +#undef PUGI__MSVC_CRT_VERSION +#undef PUGI__NS_BEGIN +#undef PUGI__NS_END +#undef PUGI__FN +#undef PUGI__FN_NO_INLINE +#undef PUGI__IS_CHARTYPE_IMPL +#undef PUGI__IS_CHARTYPE +#undef PUGI__IS_CHARTYPEX +#undef PUGI__SKIPWS +#undef PUGI__OPTSET +#undef PUGI__PUSHNODE +#undef PUGI__POPNODE +#undef PUGI__SCANFOR +#undef PUGI__SCANWHILE +#undef PUGI__ENDSEG +#undef PUGI__THROW_ERROR +#undef PUGI__CHECK_ERROR + +#endif + +/** + * Copyright (c) 2006-2012 Arseny Kapoulkine + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ diff --git a/contrib/other-builds/extract-mixed-syntax/pugixml.hpp b/contrib/other-builds/extract-mixed-syntax/pugixml.hpp new file mode 100644 index 000000000..77b4dcf47 --- /dev/null +++ b/contrib/other-builds/extract-mixed-syntax/pugixml.hpp @@ -0,0 +1,1265 @@ +/** + * pugixml parser - version 1.2 + * -------------------------------------------------------- + * Copyright (C) 2006-2012, by Arseny Kapoulkine (arseny.kapoulkine@gmail.com) + * Report bugs and download new versions at http://pugixml.org/ + * + * This library is distributed under the MIT License. See notice at the end + * of this file. 
+ * + * This work is based on the pugxml parser, which is: + * Copyright (C) 2003, by Kristen Wegner (kristen@tima.net) + */ + +#ifndef PUGIXML_VERSION +// Define version macro; evaluates to major * 100 + minor so that it's safe to use in less-than comparisons +# define PUGIXML_VERSION 120 +#endif + +// Include user configuration file (this can define various configuration macros) +#include "pugiconfig.hpp" + +#ifndef HEADER_PUGIXML_HPP +#define HEADER_PUGIXML_HPP + +// Include stddef.h for size_t and ptrdiff_t +#include + +// Include exception header for XPath +#if !defined(PUGIXML_NO_XPATH) && !defined(PUGIXML_NO_EXCEPTIONS) +# include +#endif + +// Include STL headers +#ifndef PUGIXML_NO_STL +# include +# include +# include +#endif + +// Macro for deprecated features +#ifndef PUGIXML_DEPRECATED +# if defined(__GNUC__) +# define PUGIXML_DEPRECATED __attribute__((deprecated)) +# elif defined(_MSC_VER) && _MSC_VER >= 1300 +# define PUGIXML_DEPRECATED __declspec(deprecated) +# else +# define PUGIXML_DEPRECATED +# endif +#endif + +// If no API is defined, assume default +#ifndef PUGIXML_API +# define PUGIXML_API +#endif + +// If no API for classes is defined, assume default +#ifndef PUGIXML_CLASS +# define PUGIXML_CLASS PUGIXML_API +#endif + +// If no API for functions is defined, assume default +#ifndef PUGIXML_FUNCTION +# define PUGIXML_FUNCTION PUGIXML_API +#endif + +// Character interface macros +#ifdef PUGIXML_WCHAR_MODE +# define PUGIXML_TEXT(t) L ## t +# define PUGIXML_CHAR wchar_t +#else +# define PUGIXML_TEXT(t) t +# define PUGIXML_CHAR char +#endif + +namespace pugi +{ + // Character type used for all internal storage and operations; depends on PUGIXML_WCHAR_MODE + typedef PUGIXML_CHAR char_t; + +#ifndef PUGIXML_NO_STL + // String type used for operations that work with STL string; depends on PUGIXML_WCHAR_MODE + typedef std::basic_string, std::allocator > string_t; +#endif +} + +// The PugiXML namespace +namespace pugi +{ + // Tree node types + enum 
xml_node_type + { + node_null, // Empty (null) node handle + node_document, // A document tree's absolute root + node_element, // Element tag, i.e. '' + node_pcdata, // Plain character data, i.e. 'text' + node_cdata, // Character data, i.e. '' + node_comment, // Comment tag, i.e. '' + node_pi, // Processing instruction, i.e. '' + node_declaration, // Document declaration, i.e. '' + node_doctype // Document type declaration, i.e. '' + }; + + // Parsing options + + // Minimal parsing mode (equivalent to turning all other flags off). + // Only elements and PCDATA sections are added to the DOM tree, no text conversions are performed. + const unsigned int parse_minimal = 0x0000; + + // This flag determines if processing instructions (node_pi) are added to the DOM tree. This flag is off by default. + const unsigned int parse_pi = 0x0001; + + // This flag determines if comments (node_comment) are added to the DOM tree. This flag is off by default. + const unsigned int parse_comments = 0x0002; + + // This flag determines if CDATA sections (node_cdata) are added to the DOM tree. This flag is on by default. + const unsigned int parse_cdata = 0x0004; + + // This flag determines if plain character data (node_pcdata) that consist only of whitespace are added to the DOM tree. + // This flag is off by default; turning it on usually results in slower parsing and more memory consumption. + const unsigned int parse_ws_pcdata = 0x0008; + + // This flag determines if character and entity references are expanded during parsing. This flag is on by default. + const unsigned int parse_escapes = 0x0010; + + // This flag determines if EOL characters are normalized (converted to #xA) during parsing. This flag is on by default. + const unsigned int parse_eol = 0x0020; + + // This flag determines if attribute values are normalized using CDATA normalization rules during parsing. This flag is on by default. 
+ const unsigned int parse_wconv_attribute = 0x0040; + + // This flag determines if attribute values are normalized using NMTOKENS normalization rules during parsing. This flag is off by default. + const unsigned int parse_wnorm_attribute = 0x0080; + + // This flag determines if document declaration (node_declaration) is added to the DOM tree. This flag is off by default. + const unsigned int parse_declaration = 0x0100; + + // This flag determines if document type declaration (node_doctype) is added to the DOM tree. This flag is off by default. + const unsigned int parse_doctype = 0x0200; + + // This flag determines if plain character data (node_pcdata) that is the only child of the parent node and that consists only + // of whitespace is added to the DOM tree. + // This flag is off by default; turning it on may result in slower parsing and more memory consumption. + const unsigned int parse_ws_pcdata_single = 0x0400; + + // The default parsing mode. + // Elements, PCDATA and CDATA sections are added to the DOM tree, character/reference entities are expanded, + // End-of-Line characters are normalized, attribute values are normalized using CDATA normalization rules. + const unsigned int parse_default = parse_cdata | parse_escapes | parse_wconv_attribute | parse_eol; + + // The full parsing mode. + // Nodes of all types are added to the DOM tree, character/reference entities are expanded, + // End-of-Line characters are normalized, attribute values are normalized using CDATA normalization rules. 
+ const unsigned int parse_full = parse_default | parse_pi | parse_comments | parse_declaration | parse_doctype; + + // These flags determine the encoding of input data for XML document + enum xml_encoding + { + encoding_auto, // Auto-detect input encoding using BOM or < / class xml_object_range + { + public: + typedef It const_iterator; + + xml_object_range(It b, It e): _begin(b), _end(e) + { + } + + It begin() const { return _begin; } + It end() const { return _end; } + + private: + It _begin, _end; + }; + + // Writer interface for node printing (see xml_node::print) + class PUGIXML_CLASS xml_writer + { + public: + virtual ~xml_writer() {} + + // Write memory chunk into stream/file/whatever + virtual void write(const void* data, size_t size) = 0; + }; + + // xml_writer implementation for FILE* + class PUGIXML_CLASS xml_writer_file: public xml_writer + { + public: + // Construct writer from a FILE* object; void* is used to avoid header dependencies on stdio + xml_writer_file(void* file); + + virtual void write(const void* data, size_t size); + + private: + void* file; + }; + + #ifndef PUGIXML_NO_STL + // xml_writer implementation for streams + class PUGIXML_CLASS xml_writer_stream: public xml_writer + { + public: + // Construct writer from an output stream object + xml_writer_stream(std::basic_ostream >& stream); + xml_writer_stream(std::basic_ostream >& stream); + + virtual void write(const void* data, size_t size); + + private: + std::basic_ostream >* narrow_stream; + std::basic_ostream >* wide_stream; + }; + #endif + + // A light-weight handle for manipulating attributes in DOM tree + class PUGIXML_CLASS xml_attribute + { + friend class xml_attribute_iterator; + friend class xml_node; + + private: + xml_attribute_struct* _attr; + + typedef void (*unspecified_bool_type)(xml_attribute***); + + public: + // Default constructor. Constructs an empty attribute. 
+ xml_attribute(); + + // Constructs attribute from internal pointer + explicit xml_attribute(xml_attribute_struct* attr); + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + + // Comparison operators (compares wrapped attribute pointers) + bool operator==(const xml_attribute& r) const; + bool operator!=(const xml_attribute& r) const; + bool operator<(const xml_attribute& r) const; + bool operator>(const xml_attribute& r) const; + bool operator<=(const xml_attribute& r) const; + bool operator>=(const xml_attribute& r) const; + + // Check if attribute is empty + bool empty() const; + + // Get attribute name/value, or "" if attribute is empty + const char_t* name() const; + const char_t* value() const; + + // Get attribute value, or the default value if attribute is empty + const char_t* as_string(const char_t* def = PUGIXML_TEXT("")) const; + + // Get attribute value as a number, or the default value if conversion did not succeed or attribute is empty + int as_int(int def = 0) const; + unsigned int as_uint(unsigned int def = 0) const; + double as_double(double def = 0) const; + float as_float(float def = 0) const; + + // Get attribute value as bool (returns true if first character is in '1tTyY' set), or the default value if attribute is empty + bool as_bool(bool def = false) const; + + // Set attribute name/value (returns false if attribute is empty or there is not enough memory) + bool set_name(const char_t* rhs); + bool set_value(const char_t* rhs); + + // Set attribute value with type conversion (numbers are converted to strings, boolean is converted to "true"/"false") + bool set_value(int rhs); + bool set_value(unsigned int rhs); + bool set_value(double rhs); + bool set_value(bool rhs); + + // Set attribute value (equivalent to set_value without error checking) + xml_attribute& operator=(const char_t* rhs); + xml_attribute& operator=(int rhs); + xml_attribute& operator=(unsigned 
int rhs); + xml_attribute& operator=(double rhs); + xml_attribute& operator=(bool rhs); + + // Get next/previous attribute in the attribute list of the parent node + xml_attribute next_attribute() const; + xml_attribute previous_attribute() const; + + // Get hash value (unique for handles to the same object) + size_t hash_value() const; + + // Get internal pointer + xml_attribute_struct* internal_object() const; + }; + +#ifdef __BORLANDC__ + // Borland C++ workaround + bool PUGIXML_FUNCTION operator&&(const xml_attribute& lhs, bool rhs); + bool PUGIXML_FUNCTION operator||(const xml_attribute& lhs, bool rhs); +#endif + + // A light-weight handle for manipulating nodes in DOM tree + class PUGIXML_CLASS xml_node + { + friend class xml_attribute_iterator; + friend class xml_node_iterator; + friend class xml_named_node_iterator; + + protected: + xml_node_struct* _root; + + typedef void (*unspecified_bool_type)(xml_node***); + + public: + // Default constructor. Constructs an empty node. + xml_node(); + + // Constructs node from internal pointer + explicit xml_node(xml_node_struct* p); + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + + // Comparison operators (compares wrapped node pointers) + bool operator==(const xml_node& r) const; + bool operator!=(const xml_node& r) const; + bool operator<(const xml_node& r) const; + bool operator>(const xml_node& r) const; + bool operator<=(const xml_node& r) const; + bool operator>=(const xml_node& r) const; + + // Check if node is empty. 
+ bool empty() const; + + // Get node type + xml_node_type type() const; + + // Get node name/value, or "" if node is empty or it has no name/value + const char_t* name() const; + const char_t* value() const; + + // Get attribute list + xml_attribute first_attribute() const; + xml_attribute last_attribute() const; + + // Get children list + xml_node first_child() const; + xml_node last_child() const; + + // Get next/previous sibling in the children list of the parent node + xml_node next_sibling() const; + xml_node previous_sibling() const; + + // Get parent node + xml_node parent() const; + + // Get root of DOM tree this node belongs to + xml_node root() const; + + // Get text object for the current node + xml_text text() const; + + // Get child, attribute or next/previous sibling with the specified name + xml_node child(const char_t* name) const; + xml_attribute attribute(const char_t* name) const; + xml_node next_sibling(const char_t* name) const; + xml_node previous_sibling(const char_t* name) const; + + // Get child value of current node; that is, value of the first child node of type PCDATA/CDATA + const char_t* child_value() const; + + // Get child value of child with specified name. Equivalent to child(name).child_value(). + const char_t* child_value(const char_t* name) const; + + // Set node name/value (returns false if node is empty, there is not enough memory, or node can not have name/value) + bool set_name(const char_t* rhs); + bool set_value(const char_t* rhs); + + // Add attribute with specified name. Returns added attribute, or empty attribute on errors. + xml_attribute append_attribute(const char_t* name); + xml_attribute prepend_attribute(const char_t* name); + xml_attribute insert_attribute_after(const char_t* name, const xml_attribute& attr); + xml_attribute insert_attribute_before(const char_t* name, const xml_attribute& attr); + + // Add a copy of the specified attribute. Returns added attribute, or empty attribute on errors. 
+ xml_attribute append_copy(const xml_attribute& proto); + xml_attribute prepend_copy(const xml_attribute& proto); + xml_attribute insert_copy_after(const xml_attribute& proto, const xml_attribute& attr); + xml_attribute insert_copy_before(const xml_attribute& proto, const xml_attribute& attr); + + // Add child node with specified type. Returns added node, or empty node on errors. + xml_node append_child(xml_node_type type = node_element); + xml_node prepend_child(xml_node_type type = node_element); + xml_node insert_child_after(xml_node_type type, const xml_node& node); + xml_node insert_child_before(xml_node_type type, const xml_node& node); + + // Add child element with specified name. Returns added node, or empty node on errors. + xml_node append_child(const char_t* name); + xml_node prepend_child(const char_t* name); + xml_node insert_child_after(const char_t* name, const xml_node& node); + xml_node insert_child_before(const char_t* name, const xml_node& node); + + // Add a copy of the specified node as a child. Returns added node, or empty node on errors. + xml_node append_copy(const xml_node& proto); + xml_node prepend_copy(const xml_node& proto); + xml_node insert_copy_after(const xml_node& proto, const xml_node& node); + xml_node insert_copy_before(const xml_node& proto, const xml_node& node); + + // Remove specified attribute + bool remove_attribute(const xml_attribute& a); + bool remove_attribute(const char_t* name); + + // Remove specified child + bool remove_child(const xml_node& n); + bool remove_child(const char_t* name); + + // Find attribute using predicate. Returns first attribute for which predicate returned true. + template xml_attribute find_attribute(Predicate pred) const + { + if (!_root) return xml_attribute(); + + for (xml_attribute attrib = first_attribute(); attrib; attrib = attrib.next_attribute()) + if (pred(attrib)) + return attrib; + + return xml_attribute(); + } + + // Find child node using predicate. 
Returns first child for which predicate returned true. + template xml_node find_child(Predicate pred) const + { + if (!_root) return xml_node(); + + for (xml_node node = first_child(); node; node = node.next_sibling()) + if (pred(node)) + return node; + + return xml_node(); + } + + // Find node from subtree using predicate. Returns first node from subtree (depth-first), for which predicate returned true. + template xml_node find_node(Predicate pred) const + { + if (!_root) return xml_node(); + + xml_node cur = first_child(); + + while (cur._root && cur._root != _root) + { + if (pred(cur)) return cur; + + if (cur.first_child()) cur = cur.first_child(); + else if (cur.next_sibling()) cur = cur.next_sibling(); + else + { + while (!cur.next_sibling() && cur._root != _root) cur = cur.parent(); + + if (cur._root != _root) cur = cur.next_sibling(); + } + } + + return xml_node(); + } + + // Find child node by attribute name/value + xml_node find_child_by_attribute(const char_t* name, const char_t* attr_name, const char_t* attr_value) const; + xml_node find_child_by_attribute(const char_t* attr_name, const char_t* attr_value) const; + + #ifndef PUGIXML_NO_STL + // Get the absolute node path from root as a text string. + string_t path(char_t delimiter = '/') const; + #endif + + // Search for a node by path consisting of node names and . or .. elements. + xml_node first_element_by_path(const char_t* path, char_t delimiter = '/') const; + + // Recursively traverse subtree with xml_tree_walker + bool traverse(xml_tree_walker& walker); + + #ifndef PUGIXML_NO_XPATH + // Select single node by evaluating XPath query. Returns first node from the resulting node set. 
+ xpath_node select_single_node(const char_t* query, xpath_variable_set* variables = 0) const; + xpath_node select_single_node(const xpath_query& query) const; + + // Select node set by evaluating XPath query + xpath_node_set select_nodes(const char_t* query, xpath_variable_set* variables = 0) const; + xpath_node_set select_nodes(const xpath_query& query) const; + #endif + + // Print subtree using a writer object + void print(xml_writer& writer, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto, unsigned int depth = 0) const; + + #ifndef PUGIXML_NO_STL + // Print subtree to stream + void print(std::basic_ostream >& os, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto, unsigned int depth = 0) const; + void print(std::basic_ostream >& os, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, unsigned int depth = 0) const; + #endif + + // Child nodes iterators + typedef xml_node_iterator iterator; + + iterator begin() const; + iterator end() const; + + // Attribute iterators + typedef xml_attribute_iterator attribute_iterator; + + attribute_iterator attributes_begin() const; + attribute_iterator attributes_end() const; + + // Range-based for support + xml_object_range children() const; + xml_object_range children(const char_t* name) const; + xml_object_range attributes() const; + + // Get node offset in parsed file/string (in char_t units) for debugging purposes + ptrdiff_t offset_debug() const; + + // Get hash value (unique for handles to the same object) + size_t hash_value() const; + + // Get internal pointer + xml_node_struct* internal_object() const; + }; + +#ifdef __BORLANDC__ + // Borland C++ workaround + bool PUGIXML_FUNCTION operator&&(const xml_node& lhs, bool rhs); + bool PUGIXML_FUNCTION operator||(const xml_node& lhs, bool rhs); +#endif + + // A helper for working with text inside PCDATA nodes + 
class PUGIXML_CLASS xml_text + { + friend class xml_node; + + xml_node_struct* _root; + + typedef void (*unspecified_bool_type)(xml_text***); + + explicit xml_text(xml_node_struct* root); + + xml_node_struct* _data_new(); + xml_node_struct* _data() const; + + public: + // Default constructor. Constructs an empty object. + xml_text(); + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + + // Check if text object is empty + bool empty() const; + + // Get text, or "" if object is empty + const char_t* get() const; + + // Get text, or the default value if object is empty + const char_t* as_string(const char_t* def = PUGIXML_TEXT("")) const; + + // Get text as a number, or the default value if conversion did not succeed or object is empty + int as_int(int def = 0) const; + unsigned int as_uint(unsigned int def = 0) const; + double as_double(double def = 0) const; + float as_float(float def = 0) const; + + // Get text as bool (returns true if first character is in '1tTyY' set), or the default value if object is empty + bool as_bool(bool def = false) const; + + // Set text (returns false if object is empty or there is not enough memory) + bool set(const char_t* rhs); + + // Set text with type conversion (numbers are converted to strings, boolean is converted to "true"/"false") + bool set(int rhs); + bool set(unsigned int rhs); + bool set(double rhs); + bool set(bool rhs); + + // Set text (equivalent to set without error checking) + xml_text& operator=(const char_t* rhs); + xml_text& operator=(int rhs); + xml_text& operator=(unsigned int rhs); + xml_text& operator=(double rhs); + xml_text& operator=(bool rhs); + + // Get the data node (node_pcdata or node_cdata) for this object + xml_node data() const; + }; + +#ifdef __BORLANDC__ + // Borland C++ workaround + bool PUGIXML_FUNCTION operator&&(const xml_text& lhs, bool rhs); + bool PUGIXML_FUNCTION operator||(const xml_text& lhs, bool rhs); 
+#endif + + // Child node iterator (a bidirectional iterator over a collection of xml_node) + class PUGIXML_CLASS xml_node_iterator + { + friend class xml_node; + + private: + mutable xml_node _wrap; + xml_node _parent; + + xml_node_iterator(xml_node_struct* ref, xml_node_struct* parent); + + public: + // Iterator traits + typedef ptrdiff_t difference_type; + typedef xml_node value_type; + typedef xml_node* pointer; + typedef xml_node& reference; + + #ifndef PUGIXML_NO_STL + typedef std::bidirectional_iterator_tag iterator_category; + #endif + + // Default constructor + xml_node_iterator(); + + // Construct an iterator which points to the specified node + xml_node_iterator(const xml_node& node); + + // Iterator operators + bool operator==(const xml_node_iterator& rhs) const; + bool operator!=(const xml_node_iterator& rhs) const; + + xml_node& operator*() const; + xml_node* operator->() const; + + const xml_node_iterator& operator++(); + xml_node_iterator operator++(int); + + const xml_node_iterator& operator--(); + xml_node_iterator operator--(int); + }; + + // Attribute iterator (a bidirectional iterator over a collection of xml_attribute) + class PUGIXML_CLASS xml_attribute_iterator + { + friend class xml_node; + + private: + mutable xml_attribute _wrap; + xml_node _parent; + + xml_attribute_iterator(xml_attribute_struct* ref, xml_node_struct* parent); + + public: + // Iterator traits + typedef ptrdiff_t difference_type; + typedef xml_attribute value_type; + typedef xml_attribute* pointer; + typedef xml_attribute& reference; + + #ifndef PUGIXML_NO_STL + typedef std::bidirectional_iterator_tag iterator_category; + #endif + + // Default constructor + xml_attribute_iterator(); + + // Construct an iterator which points to the specified attribute + xml_attribute_iterator(const xml_attribute& attr, const xml_node& parent); + + // Iterator operators + bool operator==(const xml_attribute_iterator& rhs) const; + bool operator!=(const xml_attribute_iterator& rhs) const; + 
+ xml_attribute& operator*() const; + xml_attribute* operator->() const; + + const xml_attribute_iterator& operator++(); + xml_attribute_iterator operator++(int); + + const xml_attribute_iterator& operator--(); + xml_attribute_iterator operator--(int); + }; + + // Named node range helper + class xml_named_node_iterator + { + public: + // Iterator traits + typedef ptrdiff_t difference_type; + typedef xml_node value_type; + typedef xml_node* pointer; + typedef xml_node& reference; + + #ifndef PUGIXML_NO_STL + typedef std::forward_iterator_tag iterator_category; + #endif + + // Default constructor + xml_named_node_iterator(); + + // Construct an iterator which points to the specified node + xml_named_node_iterator(const xml_node& node, const char_t* name); + + // Iterator operators + bool operator==(const xml_named_node_iterator& rhs) const; + bool operator!=(const xml_named_node_iterator& rhs) const; + + xml_node& operator*() const; + xml_node* operator->() const; + + const xml_named_node_iterator& operator++(); + xml_named_node_iterator operator++(int); + + private: + mutable xml_node _node; + const char_t* _name; + }; + + // Abstract tree walker class (see xml_node::traverse) + class PUGIXML_CLASS xml_tree_walker + { + friend class xml_node; + + private: + int _depth; + + protected: + // Get current traversal depth + int depth() const; + + public: + xml_tree_walker(); + virtual ~xml_tree_walker(); + + // Callback that is called when traversal begins + virtual bool begin(xml_node& node); + + // Callback that is called for each node traversed + virtual bool for_each(xml_node& node) = 0; + + // Callback that is called when traversal ends + virtual bool end(xml_node& node); + }; + + // Parsing status, returned as part of xml_parse_result object + enum xml_parse_status + { + status_ok = 0, // No error + + status_file_not_found, // File was not found during load_file() + status_io_error, // Error reading from file/stream + status_out_of_memory, // Could not allocate 
memory + status_internal_error, // Internal error occurred + + status_unrecognized_tag, // Parser could not determine tag type + + status_bad_pi, // Parsing error occurred while parsing document declaration/processing instruction + status_bad_comment, // Parsing error occurred while parsing comment + status_bad_cdata, // Parsing error occurred while parsing CDATA section + status_bad_doctype, // Parsing error occurred while parsing document type declaration + status_bad_pcdata, // Parsing error occurred while parsing PCDATA section + status_bad_start_element, // Parsing error occurred while parsing start element tag + status_bad_attribute, // Parsing error occurred while parsing element attribute + status_bad_end_element, // Parsing error occurred while parsing end element tag + status_end_element_mismatch // There was a mismatch of start-end tags (closing tag had incorrect name, some tag was not closed or there was an excessive closing tag) + }; + + // Parsing result + struct PUGIXML_CLASS xml_parse_result + { + // Parsing status (see xml_parse_status) + xml_parse_status status; + + // Last parsed offset (in char_t units from start of input data) + ptrdiff_t offset; + + // Source document encoding + xml_encoding encoding; + + // Default constructor, initializes object to failed state + xml_parse_result(); + + // Cast to bool operator + operator bool() const; + + // Get error description + const char* description() const; + }; + + // Document class (DOM tree root) + class PUGIXML_CLASS xml_document: public xml_node + { + private: + char_t* _buffer; + + char _memory[192]; + + // Non-copyable semantics + xml_document(const xml_document&); + const xml_document& operator=(const xml_document&); + + void create(); + void destroy(); + + xml_parse_result load_buffer_impl(void* contents, size_t size, unsigned int options, xml_encoding encoding, bool is_mutable, bool own); + + public: + // Default constructor, makes empty document + xml_document(); + + // Destructor, 
invalidates all node/attribute handles to this document + ~xml_document(); + + // Removes all nodes, leaving the empty document + void reset(); + + // Removes all nodes, then copies the entire contents of the specified document + void reset(const xml_document& proto); + + #ifndef PUGIXML_NO_STL + // Load document from stream. + xml_parse_result load(std::basic_istream >& stream, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + xml_parse_result load(std::basic_istream >& stream, unsigned int options = parse_default); + #endif + + // Load document from zero-terminated string. No encoding conversions are applied. + xml_parse_result load(const char_t* contents, unsigned int options = parse_default); + + // Load document from file + xml_parse_result load_file(const char* path, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + xml_parse_result load_file(const wchar_t* path, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Load document from buffer. Copies/converts the buffer, so it may be deleted or changed after the function returns. + xml_parse_result load_buffer(const void* contents, size_t size, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Load document from buffer, using the buffer for in-place parsing (the buffer is modified and used for storage of document data). + // You should ensure that buffer data will persist throughout the document's lifetime, and free the buffer memory manually once document is destroyed. + xml_parse_result load_buffer_inplace(void* contents, size_t size, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Load document from buffer, using the buffer for in-place parsing (the buffer is modified and used for storage of document data). 
+ // You should allocate the buffer with pugixml allocation function; document will free the buffer when it is no longer needed (you can't use it anymore). + xml_parse_result load_buffer_inplace_own(void* contents, size_t size, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Save XML document to writer (semantics is slightly different from xml_node::print, see documentation for details). + void save(xml_writer& writer, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto) const; + + #ifndef PUGIXML_NO_STL + // Save XML document to stream (semantics is slightly different from xml_node::print, see documentation for details). + void save(std::basic_ostream >& stream, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto) const; + void save(std::basic_ostream >& stream, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default) const; + #endif + + // Save XML to file + bool save_file(const char* path, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto) const; + bool save_file(const wchar_t* path, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto) const; + + // Get document element + xml_node document_element() const; + }; + +#ifndef PUGIXML_NO_XPATH + // XPath query return type + enum xpath_value_type + { + xpath_type_none, // Unknown type (query failed to compile) + xpath_type_node_set, // Node set (xpath_node_set) + xpath_type_number, // Number + xpath_type_string, // String + xpath_type_boolean // Boolean + }; + + // XPath parsing result + struct PUGIXML_CLASS xpath_parse_result + { + // Error message (0 if no error) + const char* error; + + // Last parsed offset (in char_t units from string start) + ptrdiff_t offset; + + // Default constructor, 
initializes object to failed state + xpath_parse_result(); + + // Cast to bool operator + operator bool() const; + + // Get error description + const char* description() const; + }; + + // A single XPath variable + class PUGIXML_CLASS xpath_variable + { + friend class xpath_variable_set; + + protected: + xpath_value_type _type; + xpath_variable* _next; + + xpath_variable(); + + // Non-copyable semantics + xpath_variable(const xpath_variable&); + xpath_variable& operator=(const xpath_variable&); + + public: + // Get variable name + const char_t* name() const; + + // Get variable type + xpath_value_type type() const; + + // Get variable value; no type conversion is performed, default value (false, NaN, empty string, empty node set) is returned on type mismatch error + bool get_boolean() const; + double get_number() const; + const char_t* get_string() const; + const xpath_node_set& get_node_set() const; + + // Set variable value; no type conversion is performed, false is returned on type mismatch error + bool set(bool value); + bool set(double value); + bool set(const char_t* value); + bool set(const xpath_node_set& value); + }; + + // A set of XPath variables + class PUGIXML_CLASS xpath_variable_set + { + private: + xpath_variable* _data[64]; + + // Non-copyable semantics + xpath_variable_set(const xpath_variable_set&); + xpath_variable_set& operator=(const xpath_variable_set&); + + xpath_variable* find(const char_t* name) const; + + public: + // Default constructor/destructor + xpath_variable_set(); + ~xpath_variable_set(); + + // Add a new variable or get the existing one, if the types match + xpath_variable* add(const char_t* name, xpath_value_type type); + + // Set value of an existing variable; no type conversion is performed, false is returned if there is no such variable or if types mismatch + bool set(const char_t* name, bool value); + bool set(const char_t* name, double value); + bool set(const char_t* name, const char_t* value); + bool set(const char_t* 
name, const xpath_node_set& value); + + // Get existing variable by name + xpath_variable* get(const char_t* name); + const xpath_variable* get(const char_t* name) const; + }; + + // A compiled XPath query object + class PUGIXML_CLASS xpath_query + { + private: + void* _impl; + xpath_parse_result _result; + + typedef void (*unspecified_bool_type)(xpath_query***); + + // Non-copyable semantics + xpath_query(const xpath_query&); + xpath_query& operator=(const xpath_query&); + + public: + // Construct a compiled object from XPath expression. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws xpath_exception on compilation errors. + explicit xpath_query(const char_t* query, xpath_variable_set* variables = 0); + + // Destructor + ~xpath_query(); + + // Get query expression return type + xpath_value_type return_type() const; + + // Evaluate expression as boolean value in the specified context; performs type conversion if necessary. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws std::bad_alloc on out of memory errors. + bool evaluate_boolean(const xpath_node& n) const; + + // Evaluate expression as double value in the specified context; performs type conversion if necessary. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws std::bad_alloc on out of memory errors. + double evaluate_number(const xpath_node& n) const; + + #ifndef PUGIXML_NO_STL + // Evaluate expression as string value in the specified context; performs type conversion if necessary. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws std::bad_alloc on out of memory errors. + string_t evaluate_string(const xpath_node& n) const; + #endif + + // Evaluate expression as string value in the specified context; performs type conversion if necessary. + // At most capacity characters are written to the destination buffer, full result size is returned (includes terminating zero). + // If PUGIXML_NO_EXCEPTIONS is not defined, throws std::bad_alloc on out of memory errors. 
+ // If PUGIXML_NO_EXCEPTIONS is defined, returns empty set instead. + size_t evaluate_string(char_t* buffer, size_t capacity, const xpath_node& n) const; + + // Evaluate expression as node set in the specified context. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws xpath_exception on type mismatch and std::bad_alloc on out of memory errors. + // If PUGIXML_NO_EXCEPTIONS is defined, returns empty node set instead. + xpath_node_set evaluate_node_set(const xpath_node& n) const; + + // Get parsing result (used to get compilation errors in PUGIXML_NO_EXCEPTIONS mode) + const xpath_parse_result& result() const; + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + }; + + #ifndef PUGIXML_NO_EXCEPTIONS + // XPath exception class + class PUGIXML_CLASS xpath_exception: public std::exception + { + private: + xpath_parse_result _result; + + public: + // Construct exception from parse result + explicit xpath_exception(const xpath_parse_result& result); + + // Get error message + virtual const char* what() const throw(); + + // Get parse result + const xpath_parse_result& result() const; + }; + #endif + + // XPath node class (either xml_node or xml_attribute) + class PUGIXML_CLASS xpath_node + { + private: + xml_node _node; + xml_attribute _attribute; + + typedef void (*unspecified_bool_type)(xpath_node***); + + public: + // Default constructor; constructs empty XPath node + xpath_node(); + + // Construct XPath node from XML node/attribute + xpath_node(const xml_node& node); + xpath_node(const xml_attribute& attribute, const xml_node& parent); + + // Get node/attribute, if any + xml_node node() const; + xml_attribute attribute() const; + + // Get parent of contained node/attribute + xml_node parent() const; + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + + // Comparison operators + bool operator==(const 
xpath_node& n) const; + bool operator!=(const xpath_node& n) const; + }; + +#ifdef __BORLANDC__ + // Borland C++ workaround + bool PUGIXML_FUNCTION operator&&(const xpath_node& lhs, bool rhs); + bool PUGIXML_FUNCTION operator||(const xpath_node& lhs, bool rhs); +#endif + + // A fixed-size collection of XPath nodes + class PUGIXML_CLASS xpath_node_set + { + public: + // Collection type + enum type_t + { + type_unsorted, // Not ordered + type_sorted, // Sorted by document order (ascending) + type_sorted_reverse // Sorted by document order (descending) + }; + + // Constant iterator type + typedef const xpath_node* const_iterator; + + // Default constructor. Constructs empty set. + xpath_node_set(); + + // Constructs a set from iterator range; data is not checked for duplicates and is not sorted according to provided type, so be careful + xpath_node_set(const_iterator begin, const_iterator end, type_t type = type_unsorted); + + // Destructor + ~xpath_node_set(); + + // Copy constructor/assignment operator + xpath_node_set(const xpath_node_set& ns); + xpath_node_set& operator=(const xpath_node_set& ns); + + // Get collection type + type_t type() const; + + // Get collection size + size_t size() const; + + // Indexing operator + const xpath_node& operator[](size_t index) const; + + // Collection iterators + const_iterator begin() const; + const_iterator end() const; + + // Sort the collection in ascending/descending order by document order + void sort(bool reverse = false); + + // Get first node in the collection by document order + xpath_node first() const; + + // Check if collection is empty + bool empty() const; + + private: + type_t _type; + + xpath_node _storage; + + xpath_node* _begin; + xpath_node* _end; + + void _assign(const_iterator begin, const_iterator end); + }; +#endif + +#ifndef PUGIXML_NO_STL + // Convert wide string to UTF8 + std::basic_string, std::allocator > PUGIXML_FUNCTION as_utf8(const wchar_t* str); + std::basic_string, std::allocator > 
PUGIXML_FUNCTION as_utf8(const std::basic_string, std::allocator >& str); + + // Convert UTF8 to wide string + std::basic_string, std::allocator > PUGIXML_FUNCTION as_wide(const char* str); + std::basic_string, std::allocator > PUGIXML_FUNCTION as_wide(const std::basic_string, std::allocator >& str); +#endif + + // Memory allocation function interface; returns pointer to allocated memory or NULL on failure + typedef void* (*allocation_function)(size_t size); + + // Memory deallocation function interface + typedef void (*deallocation_function)(void* ptr); + + // Override default memory management functions. All subsequent allocations/deallocations will be performed via supplied functions. + void PUGIXML_FUNCTION set_memory_management_functions(allocation_function allocate, deallocation_function deallocate); + + // Get current memory management functions + allocation_function PUGIXML_FUNCTION get_memory_allocation_function(); + deallocation_function PUGIXML_FUNCTION get_memory_deallocation_function(); +} + +#if !defined(PUGIXML_NO_STL) && (defined(_MSC_VER) || defined(__ICC)) +namespace std +{ + // Workarounds for (non-standard) iterator category detection for older versions (MSVC7/IC8 and earlier) + std::bidirectional_iterator_tag PUGIXML_FUNCTION _Iter_cat(const pugi::xml_node_iterator&); + std::bidirectional_iterator_tag PUGIXML_FUNCTION _Iter_cat(const pugi::xml_attribute_iterator&); + std::forward_iterator_tag PUGIXML_FUNCTION _Iter_cat(const pugi::xml_named_node_iterator&); +} +#endif + +#if !defined(PUGIXML_NO_STL) && defined(__SUNPRO_CC) +namespace std +{ + // Workarounds for (non-standard) iterator category detection + std::bidirectional_iterator_tag PUGIXML_FUNCTION __iterator_category(const pugi::xml_node_iterator&); + std::bidirectional_iterator_tag PUGIXML_FUNCTION __iterator_category(const pugi::xml_attribute_iterator&); + std::forward_iterator_tag PUGIXML_FUNCTION __iterator_category(const pugi::xml_named_node_iterator&); +} +#endif + +#endif + +/** + 
* Copyright (c) 2006-2012 Arseny Kapoulkine + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. 
+ */ diff --git a/contrib/other-builds/extract-mixed-syntax/tables-core.cpp b/contrib/other-builds/extract-mixed-syntax/tables-core.cpp deleted file mode 100644 index c3c141b7f..000000000 --- a/contrib/other-builds/extract-mixed-syntax/tables-core.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// $Id: tables-core.cpp 3131 2010-04-13 16:29:55Z pjwilliams $ -//#include "beammain.h" -//#include "SafeGetLine.h" -#include "tables-core.h" - -#define TABLE_LINE_MAX_LENGTH 1000 -#define UNKNOWNSTR "UNK" - -// as in beamdecoder/tables.cpp -vector tokenize( const char* input ) { - vector< string > token; - bool betweenWords = true; - int start=0; - int i=0; - for(; input[i] != '\0'; i++) { - bool isSpace = (input[i] == ' ' || input[i] == '\t'); - - if (!isSpace && betweenWords) { - start = i; - betweenWords = false; - } - else if (isSpace && !betweenWords) { - token.push_back( string( input+start, i-start ) ); - betweenWords = true; - } - } - if (!betweenWords) - token.push_back( string( input+start, i-start ) ); - return token; -} - -WORD_ID Vocabulary::storeIfNew( const WORD& word ) { - map::iterator i = lookup.find( word ); - - if( i != lookup.end() ) - return i->second; - - WORD_ID id = vocab.size(); - vocab.push_back( word ); - lookup[ word ] = id; - return id; -} - -WORD_ID Vocabulary::getWordID( const WORD& word ) { - map::iterator i = lookup.find( word ); - if( i == lookup.end() ) - return 0; - return i->second; -} - -PHRASE_ID PhraseTable::storeIfNew( const PHRASE& phrase ) { - map< PHRASE, PHRASE_ID >::iterator i = lookup.find( phrase ); - if( i != lookup.end() ) - return i->second; - - PHRASE_ID id = phraseTable.size(); - phraseTable.push_back( phrase ); - lookup[ phrase ] = id; - return id; -} - -PHRASE_ID PhraseTable::getPhraseID( const PHRASE& phrase ) { - map< PHRASE, PHRASE_ID >::iterator i = lookup.find( phrase ); - if( i == lookup.end() ) - return 0; - return i->second; -} - -void PhraseTable::clear() { - lookup.clear(); - phraseTable.clear(); -} - -void 
DTable::init() { - for(int i = -10; i<10; i++) - dtable[i] = -abs( i ); -} - -/* -void DTable::load( const string& fileName ) { - ifstream inFile; - inFile.open(fileName.c_str()); - istream *inFileP = &inFile; - - char line[TABLE_LINE_MAX_LENGTH]; - int i=0; - while(true) { - i++; - SAFE_GETLINE((*inFileP), line, TABLE_LINE_MAX_LENGTH, '\n', __FILE__); - if (inFileP->eof()) break; - - vector token = tokenize( line ); - if (token.size() < 2) { - cerr << "line " << i << " in " << fileName << " too short, skipping\n"; - continue; - } - - int d = atoi( token[0].c_str() ); - double prob = log( atof( token[1].c_str() ) ); - dtable[ d ] = prob; - } -} -*/ - -double DTable::get( int distortion ) { - if (dtable.find( distortion ) == dtable.end()) - return log( 0.00001 ); - return dtable[ distortion ]; -} - diff --git a/contrib/other-builds/extract-mixed-syntax/tables-core.h b/contrib/other-builds/extract-mixed-syntax/tables-core.h deleted file mode 100644 index f039ced7e..000000000 --- a/contrib/other-builds/extract-mixed-syntax/tables-core.h +++ /dev/null @@ -1,72 +0,0 @@ -#pragma once -// $Id: tables-core.h 2416 2009-07-30 11:07:38Z hieuhoang1972 $ - -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace std; - -#define TABLE_LINE_MAX_LENGTH 1000 -#define UNKNOWNSTR "UNK" - -vector tokenize( const char[] ); - -//! 
delete and remove every element of a collection object such as map, set, list etc -template -void RemoveAllInColl(COLL &coll) -{ - for (typename COLL::const_iterator iter = coll.begin() ; iter != coll.end() ; ++iter) - { - delete (*iter); - } - coll.clear(); -} - -typedef string WORD; -typedef unsigned int WORD_ID; - -class Vocabulary { - public: - map lookup; - vector< WORD > vocab; - WORD_ID storeIfNew( const WORD& ); - WORD_ID getWordID( const WORD& ); - inline WORD &getWord( WORD_ID id ) const { WORD &i = (WORD&) vocab[ id ]; return i; } -}; - -typedef vector< WORD_ID > PHRASE; -typedef unsigned int PHRASE_ID; - -class PhraseTable { - public: - map< PHRASE, PHRASE_ID > lookup; - vector< PHRASE > phraseTable; - PHRASE_ID storeIfNew( const PHRASE& ); - PHRASE_ID getPhraseID( const PHRASE& ); - void clear(); - inline PHRASE &getPhrase( const PHRASE_ID id ) { return phraseTable[ id ]; } -}; - -typedef vector< pair< PHRASE_ID, double > > PHRASEPROBVEC; - -class TTable { - public: - map< PHRASE_ID, vector< pair< PHRASE_ID, double > > > ttable; - map< PHRASE_ID, vector< pair< PHRASE_ID, vector< double > > > > ttableMulti; -}; - -class DTable { - public: - map< int, double > dtable; - void init(); - void load( const string& ); - double get( int ); -}; - - diff --git a/contrib/other-builds/extract-ordering/.cproject b/contrib/other-builds/extract-ordering/.cproject deleted file mode 100644 index 6d03a5a6c..000000000 --- a/contrib/other-builds/extract-ordering/.cproject +++ /dev/null @@ -1,135 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/extract-rules/.cproject b/contrib/other-builds/extract-rules/.cproject deleted file mode 100644 index 0990343b2..000000000 --- a/contrib/other-builds/extract-rules/.cproject +++ /dev/null @@ -1,137 +0,0 @@ - - - - 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/extract-rules/.gitignore b/contrib/other-builds/extract-rules/.gitignore deleted file mode 100644 index 98bbc3165..000000000 --- a/contrib/other-builds/extract-rules/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/Debug diff --git a/contrib/other-builds/extract-rules/.project b/contrib/other-builds/extract-rules/.project index 29ffed2a9..d640499a8 100644 --- a/contrib/other-builds/extract-rules/.project +++ b/contrib/other-builds/extract-rules/.project @@ -25,26 +25,6 @@ org.eclipse.cdt.managedbuilder.core.ScannerConfigNature - - ExtractedRule.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/ExtractedRule.h - - - Hole.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/Hole.h - - - HoleCollection.cpp - 1 - PARENT-3-PROJECT_LOC/phrase-extract/HoleCollection.cpp - - - HoleCollection.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/HoleCollection.h - InputFileStream.cpp 1 @@ -65,11 +45,6 @@ 1 PARENT-3-PROJECT_LOC/phrase-extract/OutputFileStream.h - - RuleExtractionOptions.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/RuleExtractionOptions.h - SentenceAlignment.cpp 1 @@ -111,14 +86,9 @@ PARENT-3-PROJECT_LOC/phrase-extract/XmlTree.h - extract-rules-main.cpp + extract-main.cpp 1 - PARENT-3-PROJECT_LOC/phrase-extract/extract-rules-main.cpp - - - gzfilebuf.h - 1 - PARENT-3-PROJECT_LOC/phrase-extract/gzfilebuf.h + PARENT-3-PROJECT_LOC/phrase-extract/extract-main.cpp tables-core.cpp diff --git a/contrib/other-builds/extract/.cproject b/contrib/other-builds/extract/.cproject deleted file mode 100644 index 80655fd18..000000000 --- a/contrib/other-builds/extract/.cproject +++ /dev/null @@ -1,135 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/extractor/.cproject b/contrib/other-builds/extractor/.cproject deleted file mode 100644 index 5f0b24ef0..000000000 --- a/contrib/other-builds/extractor/.cproject +++ /dev/null @@ -1,137 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/extractor/.project b/contrib/other-builds/extractor/.project index e4fe08579..56d560019 100644 --- a/contrib/other-builds/extractor/.project +++ b/contrib/other-builds/extractor/.project @@ -4,6 +4,7 @@ mert_lib + util diff --git a/contrib/other-builds/lm/.cproject b/contrib/other-builds/lm/.cproject deleted file mode 100644 index e3e47fd7e..000000000 --- a/contrib/other-builds/lm/.cproject +++ /dev/null @@ -1,144 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/manual-label/.cproject b/contrib/other-builds/manual-label/.cproject deleted file mode 100644 index 5e9471d42..000000000 --- a/contrib/other-builds/manual-label/.cproject +++ /dev/null @@ -1,126 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/manual-label/DeEn.cpp b/contrib/other-builds/manual-label/DeEn.cpp index 7ef9d495d..ea2934c5a 100644 --- a/contrib/other-builds/manual-label/DeEn.cpp +++ b/contrib/other-builds/manual-label/DeEn.cpp @@ -1,30 +1,12 @@ #include #include "DeEn.h" +#include "Main.h" 
#include "moses/Util.h" using namespace std; extern bool g_debug; -bool IsA(const Phrase &source, int pos, int offset, int factor, const string &str) -{ - pos += offset; - if (pos >= source.size() || pos < 0) { - return false; - } - - const string &word = source[pos][factor]; - vector soughts = Moses::Tokenize(str, " "); - for (int i = 0; i < soughts.size(); ++i) { - string &sought = soughts[i]; - bool found = (word == sought); - if (found) { - return true; - } - } - return false; -} - bool Contains(const Phrase &source, int start, int end, int factor, const string &str) { for (int pos = start; pos <= end; ++pos) { @@ -38,8 +20,6 @@ bool Contains(const Phrase &source, int start, int end, int factor, const string void LabelDeEn(const Phrase &source, ostream &out) { - typedef pair Range; - typedef list Ranges; Ranges ranges; // find ranges to label @@ -48,39 +28,19 @@ void LabelDeEn(const Phrase &source, ostream &out) if (IsA(source, start, -1, 1, "VAFIN") && IsA(source, end, +1, 1, "VVINF VVPP") && !Contains(source, start, end, 1, "VAFIN VVINF VVPP VVFIN")) { - Range range(start, end); + Range range(start, end, "reorder-label"); ranges.push_back(range); } else if ((start == 0 || IsA(source, start, -1, 1, "$,")) && IsA(source, end, +1, 0, "zu") && IsA(source, end, +2, 1, "VVINF") && !Contains(source, start, end, 1, "$,")) { - Range range(start, end); + Range range(start, end, "reorder-label"); ranges.push_back(range); } } } - // output sentence, with labels - for (int pos = 0; pos < source.size(); ++pos) { - // output beginning of label - for (Ranges::const_iterator iter = ranges.begin(); iter != ranges.end(); ++iter) { - const Range &range = *iter; - if (range.first == pos) { - out << " "; - } - } - - const Word &word = source[pos]; - out << word[0] << " "; - - for (Ranges::const_iterator iter = ranges.begin(); iter != ranges.end(); ++iter) { - const Range &range = *iter; - if (range.second == pos) { - out << " "; - } - } - } - out << endl; - + 
OutputWithLabels(source, ranges, out); } + diff --git a/contrib/other-builds/manual-label/DeEn.h b/contrib/other-builds/manual-label/DeEn.h index 999c2dfbd..c24ce0079 100644 --- a/contrib/other-builds/manual-label/DeEn.h +++ b/contrib/other-builds/manual-label/DeEn.h @@ -1,10 +1,5 @@ #pragma once -#include -#include -#include - -typedef std::vector Word; -typedef std::vector Phrase; +#include "Main.h" void LabelDeEn(const Phrase &source, std::ostream &out); diff --git a/contrib/other-builds/manual-label/EnOpenNLPChunker.cpp b/contrib/other-builds/manual-label/EnOpenNLPChunker.cpp new file mode 100644 index 000000000..67c2e9d84 --- /dev/null +++ b/contrib/other-builds/manual-label/EnOpenNLPChunker.cpp @@ -0,0 +1,201 @@ +/* + * EnApacheChunker.cpp + * + * Created on: 28 Feb 2014 + * Author: hieu + */ +#include +#include +#include +#include +#include "EnOpenNLPChunker.h" +#include "moses/Util.h" + +using namespace std; + +EnOpenNLPChunker::EnOpenNLPChunker(const std::string &openNLPPath) +:m_openNLPPath(openNLPPath) +{ + // TODO Auto-generated constructor stub + +} + +EnOpenNLPChunker::~EnOpenNLPChunker() { + // TODO Auto-generated destructor stub +} + +void EnOpenNLPChunker::Process(std::istream &in, std::ostream &out, const vector &filterList) +{ + // read all input to a temp file + char *ptr = tmpnam(NULL); + string inStr(ptr); + ofstream inFile(ptr); + + string line; + while (getline(in, line)) { + Unescape(line); + inFile << line << endl; + } + inFile.close(); + + ptr = tmpnam(NULL); + string outStr(ptr); + + // execute chunker + string cmd = "cat " + inStr + " | " + + m_openNLPPath + "/bin/opennlp POSTagger " + + m_openNLPPath + "/models/en-pos-maxent.bin | " + + m_openNLPPath + "/bin/opennlp ChunkerME " + + m_openNLPPath + "/models/en-chunker.bin > " + + outStr; + //g << "Executing:" << cmd << endl; + int ret = system(cmd.c_str()); + + // read result of chunker and output as Moses xml trees + ifstream outFile(outStr.c_str()); + + size_t lineNum = 0; + while 
(getline(outFile, line)) { + //cerr << line << endl; + MosesReformat(line, out, filterList); + out << endl; + ++lineNum; + } + outFile.close(); + + // clean up temporary files + remove(inStr.c_str()); + remove(outStr.c_str()); +} + +void EnOpenNLPChunker::MosesReformat(const string &line, std::ostream &out, const vector &filterList) +{ + //cerr << "REFORMATING:" << line << endl; + bool inLabel = false; + vector toks; + Moses::Tokenize(toks, line); + for (size_t i = 0; i < toks.size(); ++i) { + const string &tok = toks[i]; + + if (tok.substr(0, 1) == "[" && tok.substr(1,1) != "_") { + // start of chunk + string label = tok.substr(1); + if (UseLabel(label, filterList)) { + out << ""; + inLabel = true; + } + } + else if (tok.substr(tok.size()-1, 1) == "]") { + // end of chunk + if (tok.size() > 1) { + if (tok.substr(1,1) == "_") { + // just a word that happens to be ] + vector factors; + Moses::Tokenize(factors, tok, "_"); + assert(factors.size() == 2); + + Escape(factors[0]); + out << factors[0] << " "; + } + else { + // a word and end of tree + string word = tok.substr(0, tok.size()-1); + + vector factors; + Moses::Tokenize(factors, word, "_"); + assert(factors.size() == 2); + + Escape(factors[0]); + out << factors[0] << " "; + } + + if (inLabel) { + out << " "; + inLabel = false; + } + } + else { + if (inLabel) { + out << " "; + inLabel = false; + } + } + + } + else { + // lexical item + vector factors; + Moses::Tokenize(factors, tok, "_"); + if (factors.size() == 2) { + Escape(factors[0]); + out << factors[0] << " "; + } + else if (factors.size() == 1) { + // word is _ + assert(tok.substr(0, 2) == "__"); + out << "_ "; + } + else { + throw "Unknown format:" + tok; + } + } + } +} + +std::string +replaceAll( std::string const& original, + std::string const& before, + std::string const& after ) +{ + std::string retval; + std::string::const_iterator end = original.end(); + std::string::const_iterator current = original.begin(); + std::string::const_iterator next = + 
std::search( current, end, before.begin(), before.end() ); + while ( next != end ) { + retval.append( current, next ); + retval.append( after ); + current = next + before.size(); + next = std::search( current, end, before.begin(), before.end() ); + } + retval.append( current, next ); + return retval; +} + +void EnOpenNLPChunker::Escape(string &line) +{ + line = replaceAll(line, "&", "&"); + line = replaceAll(line, "|", "|"); + line = replaceAll(line, "<", "<"); + line = replaceAll(line, ">", ">"); + line = replaceAll(line, "'", "'"); + line = replaceAll(line, "\"", """); + line = replaceAll(line, "[", "["); + line = replaceAll(line, "]", "]"); +} + +void EnOpenNLPChunker::Unescape(string &line) +{ + line = replaceAll(line, "|", "|"); + line = replaceAll(line, "<", "<"); + line = replaceAll(line, ">", ">"); + line = replaceAll(line, """, "\""); + line = replaceAll(line, "'", "'"); + line = replaceAll(line, "[", "["); + line = replaceAll(line, "]", "]"); + line = replaceAll(line, "&", "&"); +} + +bool EnOpenNLPChunker::UseLabel(const std::string &label, const std::vector &filterList) const +{ + if (filterList.size() == 0) { + return true; + } + + for (size_t i = 0; i < filterList.size(); ++i) { + if (label == filterList[i]) { + return true; + } + } + return false; +} diff --git a/contrib/other-builds/manual-label/EnOpenNLPChunker.h b/contrib/other-builds/manual-label/EnOpenNLPChunker.h new file mode 100644 index 000000000..df9f90e42 --- /dev/null +++ b/contrib/other-builds/manual-label/EnOpenNLPChunker.h @@ -0,0 +1,29 @@ +/* + * EnApacheChunker.h + * + * Created on: 28 Feb 2014 + * Author: hieu + */ + +#pragma once + +#include +#include +#include + +class EnOpenNLPChunker { +public: + EnOpenNLPChunker(const std::string &openNLPPath); + virtual ~EnOpenNLPChunker(); + void Process(std::istream &in, std::ostream &out, const std::vector &filterList); +protected: + const std::string m_openNLPPath; + + void Escape(std::string &line); + void Unescape(std::string &line); + + 
void MosesReformat(const std::string &line, std::ostream &out, const std::vector &filterList); + + bool UseLabel(const std::string &label, const std::vector &filterList) const; +}; + diff --git a/contrib/other-builds/manual-label/EnPhrasalVerb.cpp b/contrib/other-builds/manual-label/EnPhrasalVerb.cpp new file mode 100644 index 000000000..4bee9b941 --- /dev/null +++ b/contrib/other-builds/manual-label/EnPhrasalVerb.cpp @@ -0,0 +1,226 @@ +#include +#include +#include +#include +#include "EnPhrasalVerb.h" +#include "moses/Util.h" + +using namespace std; + +void EnPhrasalVerb(const Phrase &source, int revision, ostream &out) +{ + Ranges ranges; + + // find ranges to label + for (int start = 0; start < source.size(); ++start) { + size_t end = std::numeric_limits::max(); + + if (IsA(source, start, 0, 0, "ask asked asking")) { + end = Found(source, start, 0, "out"); + } + else if (IsA(source, start, 0, 0, "back backed backing")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "blow blown blew")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "break broke broken")) { + end = Found(source, start, 0, "down up in"); + } + else if (IsA(source, start, 0, 0, "bring brought bringing")) { + end = Found(source, start, 0, "down up in"); + } + else if (IsA(source, start, 0, 0, "call called calling")) { + end = Found(source, start, 0, "back up off"); + } + else if (IsA(source, start, 0, 0, "check checked checking")) { + end = Found(source, start, 0, "out in"); + } + else if (IsA(source, start, 0, 0, "cheer cheered cheering")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "clean cleaned cleaning")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "cross crossed crossing")) { + end = Found(source, start, 0, "out"); + } + else if (IsA(source, start, 0, 0, "cut cutting")) { + end = Found(source, start, 0, "down off out"); + } + else if (IsA(source, start, 
0, 0, "do did done")) { + end = Found(source, start, 0, "over up"); + } + else if (IsA(source, start, 0, 0, "drop dropped dropping")) { + end = Found(source, start, 0, "off"); + } + else if (IsA(source, start, 0, 0, "figure figured figuring")) { + end = Found(source, start, 0, "out"); + } + else if (IsA(source, start, 0, 0, "fill filled filling")) { + end = Found(source, start, 0, "in out up"); + } + else if (IsA(source, start, 0, 0, "find found finding")) { + end = Found(source, start, 0, "out"); + } + else if (IsA(source, start, 0, 0, "get got getting gotten")) { + end = Found(source, start, 0, "across over back"); + } + else if (IsA(source, start, 0, 0, "give given gave giving")) { + end = Found(source, start, 0, "away back out up"); + } + else if (IsA(source, start, 0, 0, "hand handed handing")) { + end = Found(source, start, 0, "down in over"); + } + else if (IsA(source, start, 0, 0, "hold held holding")) { + end = Found(source, start, 0, "back up"); + } + else if (IsA(source, start, 0, 0, "keep kept keeping")) { + end = Found(source, start, 0, "from up"); + } + else if (IsA(source, start, 0, 0, "let letting")) { + end = Found(source, start, 0, "down in"); + } + else if (IsA(source, start, 0, 0, "look looked looking")) { + end = Found(source, start, 0, "over up"); + } + else if (IsA(source, start, 0, 0, "make made making")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "mix mixed mixing")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "pass passed passing")) { + end = Found(source, start, 0, "out up"); + } + else if (IsA(source, start, 0, 0, "pay payed paying")) { + end = Found(source, start, 0, "back"); + } + else if (IsA(source, start, 0, 0, "pick picked picking")) { + end = Found(source, start, 0, "out"); + } + else if (IsA(source, start, 0, 0, "point pointed pointing")) { + end = Found(source, start, 0, "out"); + } + else if (IsA(source, start, 0, 0, "put putting")) { + end = 
Found(source, start, 0, "down off out together on"); + } + else if (IsA(source, start, 0, 0, "send sending")) { + end = Found(source, start, 0, "back"); + } + else if (IsA(source, start, 0, 0, "set setting")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "sort sorted sorting")) { + end = Found(source, start, 0, "out"); + } + else if (IsA(source, start, 0, 0, "switch switched switching")) { + end = Found(source, start, 0, "off on"); + } + else if (IsA(source, start, 0, 0, "take took taking")) { + end = Found(source, start, 0, "apart back off out"); + } + else if (IsA(source, start, 0, 0, "tear torn tearing")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "think thought thinking")) { + end = Found(source, start, 0, "over"); + } + else if (IsA(source, start, 0, 0, "thrown threw thrown throwing")) { + end = Found(source, start, 0, "away"); + } + else if (IsA(source, start, 0, 0, "turn turned turning")) { + end = Found(source, start, 0, "down off on"); + } + else if (IsA(source, start, 0, 0, "try tried trying")) { + end = Found(source, start, 0, "on out"); + } + else if (IsA(source, start, 0, 0, "use used using")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "warm warmed warming")) { + end = Found(source, start, 0, "up"); + } + else if (IsA(source, start, 0, 0, "work worked working")) { + end = Found(source, start, 0, "out"); + } + + // found range to label + if (end != std::numeric_limits::max() && + end > start + 1) { + bool add = true; + if (revision == 1 && Exist(source, + start + 1, + end - 1, + 1, + "VB VBD VBG VBN VBP VBZ")) { + // there's a verb in between + add = false; + } + + if (add) { + Range range(start + 1, end - 1, "reorder-label"); + ranges.push_back(range); + } + } + } + + OutputWithLabels(source, ranges, out); +} + +bool Exist(const Phrase &source, int start, int end, int factor, const std::string &str) +{ + vector soughts = Moses::Tokenize(str, " 
"); + for (size_t i = start; i <= end; ++i) { + const Word &word = source[i]; + bool found = Found(word, factor, soughts); + if (found) { + return true; + } + } + + return false; +} + +size_t Found(const Phrase &source, int pos, int factor, const std::string &str) +{ + const size_t MAX_RANGE = 10; + + vector soughts = Moses::Tokenize(str, " "); + vector puncts = Moses::Tokenize(". : , ;", " "); + + + size_t maxEnd = std::min(source.size(), (size_t) pos + MAX_RANGE); + for (size_t i = pos + 1; i < maxEnd; ++i) { + const Word &word = source[i]; + bool found; + + found = Found(word, factor, puncts); + if (found) { + return std::numeric_limits::max(); + } + + found = Found(word, factor, soughts); + if (found) { + return i; + } + } + + return std::numeric_limits::max(); +} + + +bool Found(const Word &word, int factor, const vector &soughts) +{ + const string &element = word[factor]; + for (size_t i = 0; i < soughts.size(); ++i) { + const string &sought = soughts[i]; + bool found = (element == sought); + if (found) { + return true; + } + } + return false; +} + + diff --git a/contrib/other-builds/manual-label/EnPhrasalVerb.h b/contrib/other-builds/manual-label/EnPhrasalVerb.h new file mode 100644 index 000000000..4cb5f7348 --- /dev/null +++ b/contrib/other-builds/manual-label/EnPhrasalVerb.h @@ -0,0 +1,11 @@ +#pragma once + +#include "Main.h" + +// roll your own identification of phrasal verbs +void EnPhrasalVerb(const Phrase &source, int revision, std::ostream &out); + +bool Exist(const Phrase &source, int start, int end, int factor, const std::string &str); +size_t Found(const Phrase &source, int pos, int factor, const std::string &str); +bool Found(const Word &word, int factor, const std::vector &soughts); + diff --git a/contrib/other-builds/manual-label/LabelByInitialLetter.cpp b/contrib/other-builds/manual-label/LabelByInitialLetter.cpp new file mode 100644 index 000000000..e4136a7ea --- /dev/null +++ b/contrib/other-builds/manual-label/LabelByInitialLetter.cpp @@ 
-0,0 +1,29 @@ +#include "LabelByInitialLetter.h" +#include "Main.h" + +using namespace std; + +void LabelByInitialLetter(const Phrase &source, std::ostream &out) +{ + Ranges ranges; + + for (int start = 0; start < source.size(); ++start) { + const string &startWord = source[start][0]; + string startChar = startWord.substr(0,1); + + for (int end = start + 1; end < source.size(); ++end) { + const string &endWord = source[end][0]; + string endChar = endWord.substr(0,1); + + if (startChar == endChar) { + Range range(start, end, startChar + "-label"); + ranges.push_back(range); + } + } + } + + OutputWithLabels(source, ranges, out); + +} + + diff --git a/contrib/other-builds/manual-label/LabelByInitialLetter.h b/contrib/other-builds/manual-label/LabelByInitialLetter.h new file mode 100644 index 000000000..ba8d34c19 --- /dev/null +++ b/contrib/other-builds/manual-label/LabelByInitialLetter.h @@ -0,0 +1,6 @@ +#pragma once + +#include "Main.h" + +void LabelByInitialLetter(const Phrase &source, std::ostream &out); + diff --git a/contrib/other-builds/manual-label/Main.cpp b/contrib/other-builds/manual-label/Main.cpp new file mode 100644 index 000000000..896f70590 --- /dev/null +++ b/contrib/other-builds/manual-label/Main.cpp @@ -0,0 +1,195 @@ +#include +#include +#include +#include "moses/Util.h" +#include "Main.h" +#include "DeEn.h" +#include "EnPhrasalVerb.h" +#include "EnOpenNLPChunker.h" +#include "LabelByInitialLetter.h" + +using namespace std; + +bool g_debug = false; + +Phrase Tokenize(const string &line); + +int main(int argc, char** argv) +{ + cerr << "Starting" << endl; + + namespace po = boost::program_options; + po::options_description desc("Options"); + desc.add_options() + ("help", "Print help messages") + + ("input,i", po::value(), "Input file. Otherwise it will read from standard in") + ("output,o", po::value(), "Output file. 
Otherwise it will print from standard out") + + ("source-language,s", po::value()->required(), "Source Language") + ("target-language,t", po::value()->required(), "Target Language") + ("revision,r", po::value()->default_value(0), "Revision") + ("filter", po::value(), "Only use labels from this comma-separated list") + + ("opennlp", po::value()->default_value(""), "Path to Apache OpenNLP toolkit") + + ; + + po::variables_map vm; + try + { + po::store(po::parse_command_line(argc, argv, desc), + vm); // can throw + + /** --help option + */ + if ( vm.count("help") ) + { + std::cout << "Basic Command Line Parameter App" << std::endl + << desc << std::endl; + return EXIT_SUCCESS; + } + + po::notify(vm); // throws on error, so do after help in case + // there are any problems + } + catch(po::error& e) + { + std::cerr << "ERROR: " << e.what() << std::endl << std::endl; + std::cerr << desc << std::endl; + return EXIT_FAILURE; + } + + istream *inStrm = &cin; + if (vm.count("input")) { + string inStr = vm["input"].as(); + cerr << "inStr=" << inStr << endl; + ifstream *inFile = new ifstream(inStr.c_str()); + inStrm = inFile; + } + + ostream *outStrm = &cout; + if (vm.count("output")) { + string outStr = vm["output"].as(); + cerr << "outStr=" << outStr << endl; + ostream *outFile = new ofstream(outStr.c_str()); + outStrm = outFile; + } + + vector filterList; + if (vm.count("filter")) { + string filter = vm["filter"].as(); + Moses::Tokenize(filterList, filter, ","); + } + + string sourceLang = vm["source-language"].as(); + string targetLang = vm["target-language"].as(); + int revision = vm["revision"].as(); + + cerr << sourceLang << " " << targetLang << " " << revision << endl; + + if (sourceLang == "en" && revision == 2) { + if (vm.count("opennlp") == 0) { + throw "Need path to openNLP toolkit"; + } + + string openNLPPath = vm["opennlp"].as(); + EnOpenNLPChunker chunker(openNLPPath); + chunker.Process(*inStrm, *outStrm, filterList); + } + else { + // process line-by-line + 
string line; + size_t lineNum = 1; + + while (getline(*inStrm, line)) { + //cerr << lineNum << ":" << line << endl; + if (lineNum % 1000 == 0) { + cerr << lineNum << " "; + } + + Phrase source = Tokenize(line); + + if (revision == 600 ) { + LabelByInitialLetter(source, *outStrm); + } + else if (sourceLang == "de" && targetLang == "en") { + LabelDeEn(source, *outStrm); + } + else if (sourceLang == "en") { + if (revision == 0 || revision == 1) { + EnPhrasalVerb(source, revision, *outStrm); + } + else if (revision == 2) { + string openNLPPath = vm["opennlp-path"].as(); + EnOpenNLPChunker chunker(openNLPPath); + } + } + + ++lineNum; + } + } + + + cerr << "Finished" << endl; + return EXIT_SUCCESS; +} + +Phrase Tokenize(const string &line) +{ + Phrase ret; + + vector toks = Moses::Tokenize(line); + for (size_t i = 0; i < toks.size(); ++i) { + Word word = Moses::Tokenize(toks[i], "|"); + ret.push_back(word); + } + + return ret; +} + +bool IsA(const Phrase &source, int pos, int offset, int factor, const string &str) +{ + pos += offset; + if (pos >= source.size() || pos < 0) { + return false; + } + + const string &word = source[pos][factor]; + vector soughts = Moses::Tokenize(str, " "); + for (int i = 0; i < soughts.size(); ++i) { + string &sought = soughts[i]; + bool found = (word == sought); + if (found) { + return true; + } + } + return false; +} + + +void OutputWithLabels(const Phrase &source, const Ranges ranges, ostream &out) +{ + // output sentence, with labels + for (int pos = 0; pos < source.size(); ++pos) { + // output beginning of label + for (Ranges::const_iterator iter = ranges.begin(); iter != ranges.end(); ++iter) { + const Range &range = *iter; + if (range.range.first == pos) { + out << " "; + } + } + + const Word &word = source[pos]; + out << word[0] << " "; + + for (Ranges::const_iterator iter = ranges.begin(); iter != ranges.end(); ++iter) { + const Range &range = *iter; + if (range.range.second == pos) { + out << " "; + } + } + } + out << endl; + +} diff 
--git a/contrib/other-builds/manual-label/Main.h b/contrib/other-builds/manual-label/Main.h new file mode 100644 index 000000000..036da0d45 --- /dev/null +++ b/contrib/other-builds/manual-label/Main.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include +#include + +typedef std::vector Word; +typedef std::vector Phrase; + +struct Range +{ + Range(int start,int end, const std::string &l) + :range(start, end) + ,label(l) + {} + + std::pair range; + std::string label; +}; + +typedef std::list Ranges; + +bool IsA(const Phrase &source, int pos, int offset, int factor, const std::string &str); +void OutputWithLabels(const Phrase &source, const Ranges ranges, std::ostream &out); + + diff --git a/contrib/other-builds/manual-label/Makefile b/contrib/other-builds/manual-label/Makefile index 60ce975cd..f24d69dc7 100644 --- a/contrib/other-builds/manual-label/Makefile +++ b/contrib/other-builds/manual-label/Makefile @@ -4,10 +4,11 @@ clean: rm -f *.o manual-label .cpp.o: - g++ -I../../../ -O6 -g -c $< + g++ -I../../../boost/include -I../../../ -O3 -g -c $< -manual-label: DeEn.o manual-label.o +OBJECTS = DeEn.o EnOpenNLPChunker.o EnPhrasalVerb.o Main.o LabelByInitialLetter.o - g++ DeEn.o manual-label.o -lz -lboost_program_options-mt -o manual-label +manual-label: $(OBJECTS) + g++ $(OBJECTS) -L../../../boost/lib64 -lz -lboost_program_options-mt -o manual-label diff --git a/contrib/other-builds/manual-label/manual-label.cpp b/contrib/other-builds/manual-label/manual-label.cpp deleted file mode 100644 index 4500d2c84..000000000 --- a/contrib/other-builds/manual-label/manual-label.cpp +++ /dev/null @@ -1,88 +0,0 @@ -#include -#include -#include -#include "moses/Util.h" -#include "DeEn.h" - -using namespace std; - -bool g_debug = false; - -Phrase Tokenize(const string &line); - -int main(int argc, char** argv) -{ - cerr << "Starting" << endl; - - namespace po = boost::program_options; - po::options_description desc("Options"); - desc.add_options() - ("help", "Print help 
messages") - ("add", "additional options") - ("source-language,s", po::value()->required(), "Source Language") - ("target-language,t", po::value()->required(), "Target Language"); - - po::variables_map vm; - try - { - po::store(po::parse_command_line(argc, argv, desc), - vm); // can throw - - /** --help option - */ - if ( vm.count("help") ) - { - std::cout << "Basic Command Line Parameter App" << std::endl - << desc << std::endl; - return EXIT_SUCCESS; - } - - po::notify(vm); // throws on error, so do after help in case - // there are any problems - } - catch(po::error& e) - { - std::cerr << "ERROR: " << e.what() << std::endl << std::endl; - std::cerr << desc << std::endl; - return EXIT_FAILURE; - } - - string sourceLang = vm["source-language"].as(); - string targetLang = vm["target-language"].as(); - cerr << sourceLang << " " << targetLang << endl; - - string line; - size_t lineNum = 1; - - while (getline(cin, line)) { - //cerr << lineNum << ":" << line << endl; - if (lineNum % 1000 == 0) { - cerr << lineNum << " "; - } - - Phrase source = Tokenize(line); - - LabelDeEn(source, cout); - - ++lineNum; - } - - - - cerr << "Finished" << endl; - return EXIT_SUCCESS; -} - -Phrase Tokenize(const string &line) -{ - Phrase ret; - - vector toks = Moses::Tokenize(line); - for (size_t i = 0; i < toks.size(); ++i) { - Word word = Moses::Tokenize(toks[i], "|"); - ret.push_back(word); - } - - return ret; -} - diff --git a/contrib/other-builds/mert_lib/.cproject b/contrib/other-builds/mert_lib/.cproject deleted file mode 100644 index c53700bac..000000000 --- a/contrib/other-builds/mert_lib/.cproject +++ /dev/null @@ -1,133 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/mira/.cproject b/contrib/other-builds/mira/.cproject deleted file mode 100644 index 
72f66b5fb..000000000 --- a/contrib/other-builds/mira/.cproject +++ /dev/null @@ -1,176 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/mira/.project b/contrib/other-builds/mira/.project deleted file mode 100644 index 03838731f..000000000 --- a/contrib/other-builds/mira/.project +++ /dev/null @@ -1,81 +0,0 @@ - - - mira - - - mert_lib - moses - - - - org.eclipse.cdt.managedbuilder.core.genmakebuilder - clean,full,incremental, - - - - - org.eclipse.cdt.managedbuilder.core.ScannerConfigBuilder - full,incremental, - - - - - - org.eclipse.cdt.core.cnature - org.eclipse.cdt.core.ccnature - org.eclipse.cdt.managedbuilder.core.managedBuildNature - org.eclipse.cdt.managedbuilder.core.ScannerConfigNature - - - - Decoder.cpp - 1 - PARENT-3-PROJECT_LOC/mira/Decoder.cpp - - - Decoder.h - 1 - PARENT-3-PROJECT_LOC/mira/Decoder.h - - - Hildreth.cpp - 1 - PARENT-3-PROJECT_LOC/mira/Hildreth.cpp - - - Hildreth.h - 1 - PARENT-3-PROJECT_LOC/mira/Hildreth.h - - - HypothesisQueue.cpp - 1 - PARENT-3-PROJECT_LOC/mira/HypothesisQueue.cpp - - - HypothesisQueue.h - 1 - PARENT-3-PROJECT_LOC/mira/HypothesisQueue.h - - - Main.cpp - 1 - PARENT-3-PROJECT_LOC/mira/Main.cpp - - - Main.h - 1 - PARENT-3-PROJECT_LOC/mira/Main.h - - - MiraOptimiser.cpp - 1 - PARENT-3-PROJECT_LOC/mira/MiraOptimiser.cpp - - - Perceptron.cpp - 1 - PARENT-3-PROJECT_LOC/mira/Perceptron.cpp - - - diff --git a/contrib/other-builds/moses-chart-cmd/.cproject b/contrib/other-builds/moses-chart-cmd/.cproject deleted file mode 100644 index 86dfbac5b..000000000 --- a/contrib/other-builds/moses-chart-cmd/.cproject +++ /dev/null @@ -1,182 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/moses-cmd/.cproject b/contrib/other-builds/moses-cmd/.cproject deleted file mode 100644 index 828b71395..000000000 --- a/contrib/other-builds/moses-cmd/.cproject +++ /dev/null @@ -1,182 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/moses/.cproject b/contrib/other-builds/moses/.cproject deleted file mode 100644 index 0c408b368..000000000 --- a/contrib/other-builds/moses/.cproject +++ /dev/null @@ -1,184 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/moses/.project b/contrib/other-builds/moses/.project index d7580e0de..1334febc0 100644 --- a/contrib/other-builds/moses/.project +++ b/contrib/other-builds/moses/.project @@ -601,16 +601,6 @@ 1 PARENT-3-PROJECT_LOC/moses/ReorderingConstraint.h - - ReorderingStack.cpp - 1 - PARENT-3-PROJECT_LOC/moses/ReorderingStack.cpp - - - ReorderingStack.h - 1 - PARENT-3-PROJECT_LOC/moses/ReorderingStack.h - RuleCube.cpp 1 @@ -1301,6 +1291,16 @@ 1 PARENT-3-PROJECT_LOC/moses/FF/SoftMatchingFeature.h + + FF/SourceGHKMTreeInputMatchFeature.cpp + 1 + PARENT-3-PROJECT_LOC/moses/FF/SourceGHKMTreeInputMatchFeature.cpp + + + FF/SourceGHKMTreeInputMatchFeature.h + 1 + PARENT-3-PROJECT_LOC/moses/FF/SourceGHKMTreeInputMatchFeature.h + FF/SourceWordDeletionFeature.cpp 1 @@ -1311,6 +1311,26 @@ 1 PARENT-3-PROJECT_LOC/moses/FF/SourceWordDeletionFeature.h + + FF/SpanLength.cpp + 1 + PARENT-3-PROJECT_LOC/moses/FF/SpanLength.cpp + + + 
FF/SpanLength.h + 1 + PARENT-3-PROJECT_LOC/moses/FF/SpanLength.h + + + FF/SparseHieroReorderingFeature.cpp + 1 + PARENT-3-PROJECT_LOC/moses/FF/SparseHieroReorderingFeature.cpp + + + FF/SparseHieroReorderingFeature.h + 1 + PARENT-3-PROJECT_LOC/moses/FF/SparseHieroReorderingFeature.h + FF/StatefulFeatureFunction.cpp 1 @@ -1331,6 +1351,16 @@ 1 PARENT-3-PROJECT_LOC/moses/FF/StatelessFeatureFunction.h + + FF/SyntaxRHS.cpp + 1 + PARENT-3-PROJECT_LOC/moses/FF/SyntaxRHS.cpp + + + FF/SyntaxRHS.h + 1 + PARENT-3-PROJECT_LOC/moses/FF/SyntaxRHS.h + FF/TargetBigramFeature.cpp 1 @@ -1606,6 +1636,21 @@ 1 PARENT-3-PROJECT_LOC/moses/LM/backward.arpa + + LM/oxlm + 2 + virtual:/virtual + + + PP/CountsPhraseProperty.cpp + 1 + PARENT-3-PROJECT_LOC/moses/PP/CountsPhraseProperty.cpp + + + PP/CountsPhraseProperty.h + 1 + PARENT-3-PROJECT_LOC/moses/PP/CountsPhraseProperty.h + PP/Factory.cpp 1 @@ -1616,11 +1661,46 @@ 1 PARENT-3-PROJECT_LOC/moses/PP/Factory.h + + PP/NonTermContextProperty.cpp + 1 + PARENT-3-PROJECT_LOC/moses/PP/NonTermContextProperty.cpp + + + PP/NonTermContextProperty.h + 1 + PARENT-3-PROJECT_LOC/moses/PP/NonTermContextProperty.h + + + PP/PhraseProperty.cpp + 1 + PARENT-3-PROJECT_LOC/moses/PP/PhraseProperty.cpp + PP/PhraseProperty.h 1 PARENT-3-PROJECT_LOC/moses/PP/PhraseProperty.h + + PP/SourceLabelsPhraseProperty.cpp + 1 + PARENT-3-PROJECT_LOC/moses/PP/SourceLabelsPhraseProperty.cpp + + + PP/SourceLabelsPhraseProperty.h + 1 + PARENT-3-PROJECT_LOC/moses/PP/SourceLabelsPhraseProperty.h + + + PP/SpanLengthPhraseProperty.cpp + 1 + PARENT-3-PROJECT_LOC/moses/PP/SpanLengthPhraseProperty.cpp + + + PP/SpanLengthPhraseProperty.h + 1 + PARENT-3-PROJECT_LOC/moses/PP/SpanLengthPhraseProperty.h + PP/TreeStructurePhraseProperty.h 1 @@ -1761,6 +1841,11 @@ 1 PARENT-3-PROJECT_LOC/moses/TranslationModel/PhraseDictionaryTreeAdaptor.h + + TranslationModel/ProbingPT + 2 + virtual:/virtual + TranslationModel/RuleTable 2 @@ -1846,6 +1931,26 @@ 1 
PARENT-3-PROJECT_LOC/moses/FF/LexicalReordering/LexicalReorderingTable.h + + FF/LexicalReordering/ReorderingStack.cpp + 1 + PARENT-3-PROJECT_LOC/moses/FF/LexicalReordering/ReorderingStack.cpp + + + FF/LexicalReordering/ReorderingStack.h + 1 + PARENT-3-PROJECT_LOC/moses/FF/LexicalReordering/ReorderingStack.h + + + FF/LexicalReordering/SparseReordering.cpp + 1 + PARENT-3-PROJECT_LOC/moses/FF/LexicalReordering/SparseReordering.cpp + + + FF/LexicalReordering/SparseReordering.h + 1 + PARENT-3-PROJECT_LOC/moses/FF/LexicalReordering/SparseReordering.h + FF/OSM-Feature/OpSequenceModel.cpp 1 @@ -1866,6 +1971,26 @@ 1 PARENT-3-PROJECT_LOC/moses/FF/OSM-Feature/osmHyp.h + + LM/oxlm/LBLLM.cpp + 1 + PARENT-3-PROJECT_LOC/moses/LM/oxlm/LBLLM.cpp + + + LM/oxlm/LBLLM.h + 1 + PARENT-3-PROJECT_LOC/moses/LM/oxlm/LBLLM.h + + + LM/oxlm/Mapper.cpp + 1 + PARENT-3-PROJECT_LOC/moses/LM/oxlm/Mapper.cpp + + + LM/oxlm/Mapper.h + 1 + PARENT-3-PROJECT_LOC/moses/LM/oxlm/Mapper.h + TranslationModel/CYKPlusParser/ChartRuleLookupManagerCYKPlus.cpp 1 @@ -2181,6 +2306,91 @@ 1 PARENT-3-PROJECT_LOC/moses/TranslationModel/DynSAInclude/vocab.h + + TranslationModel/ProbingPT/Jamfile + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/Jamfile + + + TranslationModel/ProbingPT/ProbingPT.cpp + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/ProbingPT.cpp + + + TranslationModel/ProbingPT/ProbingPT.h + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/ProbingPT.h + + + TranslationModel/ProbingPT/hash.cpp + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/hash.cpp + + + TranslationModel/ProbingPT/hash.hh + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/hash.hh + + + TranslationModel/ProbingPT/huffmanish.cpp + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/huffmanish.cpp + + + TranslationModel/ProbingPT/huffmanish.hh + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/huffmanish.hh + + + TranslationModel/ProbingPT/line_splitter.cpp + 1 + 
PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/line_splitter.cpp + + + TranslationModel/ProbingPT/line_splitter.hh + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/line_splitter.hh + + + TranslationModel/ProbingPT/probing_hash_utils.cpp + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/probing_hash_utils.cpp + + + TranslationModel/ProbingPT/probing_hash_utils.hh + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/probing_hash_utils.hh + + + TranslationModel/ProbingPT/quering.cpp + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/quering.cpp + + + TranslationModel/ProbingPT/quering.hh + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/quering.hh + + + TranslationModel/ProbingPT/storing.cpp + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/storing.cpp + + + TranslationModel/ProbingPT/storing.hh + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/storing.hh + + + TranslationModel/ProbingPT/vocabid.cpp + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/vocabid.cpp + + + TranslationModel/ProbingPT/vocabid.hh + 1 + PARENT-3-PROJECT_LOC/moses/TranslationModel/ProbingPT/vocabid.hh + TranslationModel/RuleTable/Loader.h 1 diff --git a/contrib/other-builds/score/.cproject b/contrib/other-builds/score/.cproject deleted file mode 100644 index f51f35ef5..000000000 --- a/contrib/other-builds/score/.cproject +++ /dev/null @@ -1,133 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/score/.project b/contrib/other-builds/score/.project index 05564d0f9..10e713124 100644 --- a/contrib/other-builds/score/.project +++ b/contrib/other-builds/score/.project @@ -87,16 +87,6 @@ 1 PARENT-3-PROJECT_LOC/phrase-extract/ScoreFeature.h - - exception.cc - 1 - PARENT-3-PROJECT_LOC/util/exception.cc - - - exception.hh 
- 1 - PARENT-3-PROJECT_LOC/util/exception.hh - score-main.cpp 1 diff --git a/contrib/other-builds/search/.cproject b/contrib/other-builds/search/.cproject deleted file mode 100644 index 44ae0e94e..000000000 --- a/contrib/other-builds/search/.cproject +++ /dev/null @@ -1,139 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/contrib/other-builds/util/.cproject b/contrib/other-builds/util/.cproject deleted file mode 100644 index 65b4e4fa7..000000000 --- a/contrib/other-builds/util/.cproject +++ /dev/null @@ -1,149 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/lm/builder/pipeline.cc b/lm/builder/pipeline.cc index da82c22e7..e91870808 100644 --- a/lm/builder/pipeline.cc +++ b/lm/builder/pipeline.cc @@ -302,33 +302,40 @@ void Pipeline(PipelineConfig config, int text_file, int out_arpa) { "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); UTIL_TIMER("(%w s) Total wall time elapsed\n"); + Master master(config); + // master's destructor will wait for chains. But they might be deadlocked if + // this thread dies because e.g. it ran out of memory. + try { + util::scoped_fd vocab_file(config.vocab_file.empty() ? 
+ util::MakeTemp(config.TempPrefix()) : + util::CreateOrThrow(config.vocab_file.c_str())); + uint64_t token_count; + std::string text_file_name; + CountText(text_file, vocab_file.get(), master, token_count, text_file_name); - util::scoped_fd vocab_file(config.vocab_file.empty() ? - util::MakeTemp(config.TempPrefix()) : - util::CreateOrThrow(config.vocab_file.c_str())); - uint64_t token_count; - std::string text_file_name; - CountText(text_file, vocab_file.get(), master, token_count, text_file_name); + std::vector counts; + std::vector counts_pruned; + std::vector discounts; + master >> AdjustCounts(counts, counts_pruned, discounts, config.prune_thresholds); - std::vector counts; - std::vector counts_pruned; - std::vector discounts; - master >> AdjustCounts(counts, counts_pruned, discounts, config.prune_thresholds); + { + util::FixedArray gammas; + Sorts primary; + InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds); + InterpolateProbabilities(counts_pruned, master, primary, gammas); + } - { - util::FixedArray gammas; - Sorts primary; - InitialProbabilities(counts, counts_pruned, discounts, master, primary, gammas, config.prune_thresholds); - InterpolateProbabilities(counts_pruned, master, primary, gammas); + std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; + VocabReconstitute vocab(vocab_file.get()); + UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); + HeaderInfo header_info(text_file_name, token_count); + master >> PrintARPA(vocab, counts_pruned, (config.verbose_header ? 
&header_info : NULL), out_arpa) >> util::stream::kRecycle; + master.MutableChains().Wait(true); + } catch (const util::Exception &e) { + std::cerr << e.what() << std::endl; + abort(); } - - std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; - VocabReconstitute vocab(vocab_file.get()); - UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); - HeaderInfo header_info(text_file_name, token_count); - master >> PrintARPA(vocab, counts_pruned, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; - master.MutableChains().Wait(true); } }} // namespaces diff --git a/mira/Main.cpp b/mira/Main.cpp index c22a80ece..70b5971c9 100644 --- a/mira/Main.cpp +++ b/mira/Main.cpp @@ -665,7 +665,7 @@ int main(int argc, char** argv) } // number of weight dumps this epoch - // size_t weightMixingThisEpoch = 0; + size_t weightMixingThisEpoch = 0; size_t weightEpochDump = 0; size_t shardPosition = 0; diff --git a/misc/CreateProbingPT.cpp b/misc/CreateProbingPT.cpp new file mode 100644 index 000000000..3ea369a96 --- /dev/null +++ b/misc/CreateProbingPT.cpp @@ -0,0 +1,20 @@ +#include "util/usage.hh" +#include "moses/TranslationModel/ProbingPT/storing.hh" + + + +int main(int argc, char* argv[]){ + + if (argc != 3) { + // Tell the user how to run the program + std::cerr << "Provided " << argc << " arguments, needed 3." 
<< std::endl; + std::cerr << "Usage: " << argv[0] << " path_to_phrasetable output_dir" << std::endl; + return 1; + } + + createProbingPT(argv[1], argv[2]); + + util::PrintUsage(std::cout); + return 0; +} + diff --git a/misc/Jamfile b/misc/Jamfile index 76f91babb..d466e306c 100644 --- a/misc/Jamfile +++ b/misc/Jamfile @@ -25,4 +25,15 @@ else { alias programsMin ; } -alias programs : 1-1-Extraction TMining generateSequences processPhraseTable processLexicalTable queryPhraseTable queryLexicalTable programsMin ; +if [ option.get "with-probing-pt" : : "yes" ] +{ + exe CreateProbingPT : CreateProbingPT.cpp ../moses//moses ; + exe QueryProbingPT : QueryProbingPT.cpp ../moses//moses ; + + alias programsProbing : CreateProbingPT QueryProbingPT ; +} +else { + alias programsProbing ; +} + +alias programs : 1-1-Extraction TMining generateSequences processPhraseTable processLexicalTable queryPhraseTable queryLexicalTable programsMin programsProbing ; diff --git a/misc/QueryProbingPT.cpp b/misc/QueryProbingPT.cpp new file mode 100644 index 000000000..8a3441a0d --- /dev/null +++ b/misc/QueryProbingPT.cpp @@ -0,0 +1,61 @@ +#include "util/file_piece.hh" + +#include "util/file.hh" +#include "util/scoped.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" +#include "util/murmur_hash.hh" +#include "util/probing_hash_table.hh" +#include "util/usage.hh" + +#include "moses/TranslationModel/ProbingPT/quering.hh" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include //For finding size of file +#include +#include +#include +#include +#include + +int main(int argc, char* argv[]) { + if (argc != 2) { + // Tell the user how to run the program + std::cerr << "Usage: " << argv[0] << " path_to_directory" << std::endl; + return 1; + } + + QueryEngine queries(argv[1]); + + //Interactive search + std::cout << "Please enter a string to be searched, or exit to exit." 
<< std::endl; + while (true){ + std::string cinstr = ""; + getline(std::cin, cinstr); + if (cinstr == "exit"){ + break; + }else{ + //Actual lookup + std::pair > query_result; + query_result = queries.query(StringPiece(cinstr)); + + if (query_result.first) { + queries.printTargetInfo(query_result.second); + } else { + std::cout << "Key not found!" << std::endl; + } + } + } + + util::PrintUsage(std::cout); + + return 0; +} diff --git a/moses-chart-cmd/IOWrapper.cpp b/moses-chart-cmd/IOWrapper.cpp index db12dcfd6..faf6a1f93 100644 --- a/moses-chart-cmd/IOWrapper.cpp +++ b/moses-chart-cmd/IOWrapper.cpp @@ -150,6 +150,7 @@ IOWrapper::~IOWrapper() delete m_outputSearchGraphStream; delete m_detailedTranslationReportingStream; delete m_detailedTreeFragmentsTranslationReportingStream; + delete m_detailTreeFragmentsOutputCollector; delete m_alignmentInfoStream; delete m_unknownsStream; delete m_detailOutputCollector; @@ -412,10 +413,9 @@ void IOWrapper::OutputTreeFragmentsTranslationOptions(std::ostream &out, Applica OutputTranslationOption(out, applicationContext, hypo, sentence, translationId); const TargetPhrase &currTarPhr = hypo->GetCurrTargetPhrase(); - boost::shared_ptr property; out << " ||| "; - if (currTarPhr.GetProperty("Tree", property)) { + if (const PhraseProperty *property = currTarPhr.GetProperty("Tree")) { out << " " << property->GetValueString(); } else { out << " " << "noTreeInfo"; @@ -439,10 +439,9 @@ void IOWrapper::OutputTreeFragmentsTranslationOptions(std::ostream &out, Applica OutputTranslationOption(out, applicationContext, applied, sentence, translationId); const TargetPhrase &currTarPhr = *static_cast(applied->GetNote().vp); - boost::shared_ptr property; out << " ||| "; - if (currTarPhr.GetProperty("Tree", property)) { + if (const PhraseProperty *property = currTarPhr.GetProperty("Tree")) { out << " " << property->GetValueString(); } else { out << " " << "noTreeInfo"; diff --git a/moses/ChartCellCollection.h b/moses/ChartCellCollection.h index 
d0423b0b2..1edeb4450 100644 --- a/moses/ChartCellCollection.h +++ b/moses/ChartCellCollection.h @@ -20,11 +20,11 @@ ***********************************************************************/ #pragma once +#include #include "InputType.h" #include "ChartCell.h" #include "WordsRange.h" - -#include +#include "InputPath.h" namespace Moses { @@ -36,6 +36,7 @@ class ChartCellCollectionBase public: template ChartCellCollectionBase(const InputType &input, const Factory &factory) : m_cells(input.GetSize()) { + size_t size = input.GetSize(); for (size_t startPos = 0; startPos < size; ++startPos) { std::vector &inner = m_cells[startPos]; @@ -47,12 +48,15 @@ public: * gets it from there :-(. The span is actually stored as a reference, * which needs to point somewhere, so I have it refer to the ChartCell. */ - m_source.push_back(new ChartCellLabel(inner[0]->GetCoverage(), input.GetWord(startPos))); + const WordsRange &range = inner[0]->GetCoverage(); + + m_source.push_back(new ChartCellLabel(range, input.GetWord(startPos))); } } virtual ~ChartCellCollectionBase(); + const ChartCellBase &GetBase(const WordsRange &coverage) const { return *m_cells[coverage.GetStartPos()][coverage.GetEndPos() - coverage.GetStartPos()]; } @@ -70,6 +74,7 @@ private: std::vector > m_cells; boost::ptr_vector m_source; + }; /** Hold all the chart cells for 1 input sentence. A variable of this type is held by the ChartManager diff --git a/moses/ChartCellLabel.h b/moses/ChartCellLabel.h index 144a64add..c67d985b2 100644 --- a/moses/ChartCellLabel.h +++ b/moses/ChartCellLabel.h @@ -90,6 +90,7 @@ public: private: const WordsRange &m_coverage; const Word &m_label; + //const InputPath &m_inputPath; Stack m_stack; mutable float m_bestScore; }; diff --git a/moses/ChartCellLabelSet.h b/moses/ChartCellLabelSet.h index 2b497b957..45d281d35 100644 --- a/moses/ChartCellLabelSet.h +++ b/moses/ChartCellLabelSet.h @@ -72,6 +72,8 @@ public: size_t idx = w[0]->GetId(); if (! 
ChartCellExists(idx)) { m_size++; + + m_map[idx] = new ChartCellLabel(m_coverage, w); } } diff --git a/moses/ChartHypothesis.cpp b/moses/ChartHypothesis.cpp index 212a28d23..7b32559b7 100644 --- a/moses/ChartHypothesis.cpp +++ b/moses/ChartHypothesis.cpp @@ -149,6 +149,40 @@ Phrase ChartHypothesis::GetOutputPhrase() const return outPhrase; } +void ChartHypothesis::GetOutputPhrase(int leftRightMost, int numWords, Phrase &outPhrase) const +{ + const TargetPhrase &tp = GetCurrTargetPhrase(); + + int targetSize = tp.GetSize(); + for (int i = 0; i < targetSize; ++i) { + int pos; + if (leftRightMost == 1) { + pos = i; + } + else if (leftRightMost == 2) { + pos = targetSize - i - 1; + } + else { + abort(); + } + + const Word &word = tp.GetWord(pos); + + if (word.IsNonTerminal()) { + // non-term. fill out with prev hypo + size_t nonTermInd = tp.GetAlignNonTerm().GetNonTermIndexMap()[pos]; + const ChartHypothesis *prevHypo = m_prevHypos[nonTermInd]; + prevHypo->GetOutputPhrase(outPhrase); + } else { + outPhrase.AddWord(word); + } + + if (outPhrase.GetSize() >= numWords) { + return; + } + } +} + /** check, if two hypothesis can be recombined. this is actually a sorting function that allows us to keep an ordered list of hypotheses. This makes recombination @@ -200,7 +234,7 @@ void ChartHypothesis::Evaluate() StatelessFeatureFunction::GetStatelessFeatureFunctions(); for (unsigned i = 0; i < sfs.size(); ++i) { if (! staticData.IsFeatureFunctionIgnored( *sfs[i] )) { - sfs[i]->EvaluateChart(*this,&m_scoreBreakdown); + sfs[i]->EvaluateWhenApplied(*this,&m_scoreBreakdown); } } @@ -208,7 +242,7 @@ void ChartHypothesis::Evaluate() StatefulFeatureFunction::GetStatefulFeatureFunctions(); for (unsigned i = 0; i < ffs.size(); ++i) { if (! 
staticData.IsFeatureFunctionIgnored( *ffs[i] )) { - m_ffStates[i] = ffs[i]->EvaluateChart(*this,i,&m_scoreBreakdown); + m_ffStates[i] = ffs[i]->EvaluateWhenApplied(*this,i,&m_scoreBreakdown); } } diff --git a/moses/ChartHypothesis.h b/moses/ChartHypothesis.h index 12050e764..3f159d222 100644 --- a/moses/ChartHypothesis.h +++ b/moses/ChartHypothesis.h @@ -138,6 +138,10 @@ public: void GetOutputPhrase(Phrase &outPhrase) const; Phrase GetOutputPhrase() const; + // get leftmost/rightmost words only + // leftRightMost: 1=left, 2=right + void GetOutputPhrase(int leftRightMost, int numWords, Phrase &outPhrase) const; + int RecombineCompare(const ChartHypothesis &compare) const; void Evaluate(); diff --git a/moses/ChartManager.cpp b/moses/ChartManager.cpp index 139256171..e137da915 100644 --- a/moses/ChartManager.cpp +++ b/moses/ChartManager.cpp @@ -125,7 +125,7 @@ void ChartManager::ProcessSentence() */ void ChartManager::AddXmlChartOptions() { - const StaticData &staticData = StaticData::Instance(); + // const StaticData &staticData = StaticData::Instance(); const std::vector xmlChartOptionsList = m_source.GetXmlChartTranslationOptions(); IFVERBOSE(2) { diff --git a/moses/ChartTranslationOption.cpp b/moses/ChartTranslationOption.cpp index 0fece0a09..daf1f89ce 100644 --- a/moses/ChartTranslationOption.cpp +++ b/moses/ChartTranslationOption.cpp @@ -18,7 +18,7 @@ void ChartTranslationOption::Evaluate(const InputType &input, for (size_t i = 0; i < ffs.size(); ++i) { const FeatureFunction &ff = *ffs[i]; - ff.Evaluate(input, inputPath, m_targetPhrase, &stackVec, m_scoreBreakdown); + ff.EvaluateWithSourceContext(input, inputPath, m_targetPhrase, &stackVec, m_scoreBreakdown); } } diff --git a/moses/ChartTranslationOptionList.cpp b/moses/ChartTranslationOptionList.cpp index f0f6a0732..8c52eca32 100644 --- a/moses/ChartTranslationOptionList.cpp +++ b/moses/ChartTranslationOptionList.cpp @@ -83,7 +83,7 @@ void ChartTranslationOptionList::Add(const TargetPhraseCollection &tpc, // If 
the rule limit has already been reached then don't add the option // unless it is better than at least one existing option. - if (m_size > m_ruleLimit && score < m_scoreThreshold) { + if (m_ruleLimit && m_size > m_ruleLimit && score < m_scoreThreshold) { return; } @@ -100,12 +100,12 @@ void ChartTranslationOptionList::Add(const TargetPhraseCollection &tpc, ++m_size; // If the rule limit hasn't been exceeded then update the threshold. - if (m_size <= m_ruleLimit) { + if (!m_ruleLimit || m_size <= m_ruleLimit) { m_scoreThreshold = (score < m_scoreThreshold) ? score : m_scoreThreshold; } // Prune if bursting - if (m_size == m_ruleLimit * 2) { + if (m_ruleLimit && m_size == m_ruleLimit * 2) { NTH_ELEMENT4(m_collection.begin(), m_collection.begin() + m_ruleLimit - 1, m_collection.begin() + m_size, @@ -126,7 +126,7 @@ void ChartTranslationOptionList::AddPhraseOOV(TargetPhrase &phrase, std::list m_ruleLimit) { + if (m_ruleLimit && m_size > m_ruleLimit) { // Something's gone wrong if the list has grown to m_ruleLimit * 2 // without being pruned. 
assert(m_size < m_ruleLimit * 2); diff --git a/moses/ChartTranslationOptions.cpp b/moses/ChartTranslationOptions.cpp index 3073bdcfc..114eae868 100644 --- a/moses/ChartTranslationOptions.cpp +++ b/moses/ChartTranslationOptions.cpp @@ -74,7 +74,7 @@ void ChartTranslationOptions::Evaluate(const InputType &input, const InputPath & ++numDiscard; } else if (numDiscard) { - m_collection[i - numDiscard] = boost::shared_ptr(transOpt); + m_collection[i - numDiscard] = m_collection[i]; } } diff --git a/moses/ConfusionNet.cpp b/moses/ConfusionNet.cpp index 5861ee5f1..d9270bd1b 100644 --- a/moses/ConfusionNet.cpp +++ b/moses/ConfusionNet.cpp @@ -142,7 +142,7 @@ namespace Moses { Clear(); - const StaticData &staticData = StaticData::Instance(); + // const StaticData &staticData = StaticData::Instance(); const InputFeature &inputFeature = InputFeature::Instance(); size_t numInputScores = inputFeature.GetNumInputScores(); size_t numRealWordCount = inputFeature.GetNumRealWordsInInput(); diff --git a/moses/FF/BleuScoreFeature.cpp b/moses/FF/BleuScoreFeature.cpp index 348eaa0ea..0d0a20797 100644 --- a/moses/FF/BleuScoreFeature.cpp +++ b/moses/FF/BleuScoreFeature.cpp @@ -502,7 +502,7 @@ void BleuScoreFeature::GetClippedNgramMatchesAndCounts(Phrase& phrase, * Given a previous state, compute Bleu score for the updated state with an additional target * phrase translated. 
*/ -FFState* BleuScoreFeature::Evaluate(const Hypothesis& cur_hypo, +FFState* BleuScoreFeature::EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const { @@ -563,7 +563,7 @@ FFState* BleuScoreFeature::Evaluate(const Hypothesis& cur_hypo, return new_state; } -FFState* BleuScoreFeature::EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, +FFState* BleuScoreFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator ) const { if (!m_enabled) return new BleuScoreState(); diff --git a/moses/FF/BleuScoreFeature.h b/moses/FF/BleuScoreFeature.h index 99f04f5ff..cdba578ac 100644 --- a/moses/FF/BleuScoreFeature.h +++ b/moses/FF/BleuScoreFeature.h @@ -115,20 +115,20 @@ public: std::vector< size_t >&, size_t skip = 0) const; - FFState* Evaluate( const Hypothesis& cur_hypo, + FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - FFState* EvaluateChart(const ChartHypothesis& cur_hypo, + FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/ConstrainedDecoding.cpp b/moses/FF/ConstrainedDecoding.cpp index 9a8ecd1c3..bfe412913 100644 --- a/moses/FF/ConstrainedDecoding.cpp +++ b/moses/FF/ConstrainedDecoding.cpp @@ -100,7 +100,7 @@ const std::vector *GetConstraint(const std::map } } 
-FFState* ConstrainedDecoding::Evaluate( +FFState* ConstrainedDecoding::EvaluateWhenApplied( const Hypothesis& hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const @@ -143,7 +143,7 @@ FFState* ConstrainedDecoding::Evaluate( return ret; } -FFState* ConstrainedDecoding::EvaluateChart( +FFState* ConstrainedDecoding::EvaluateWhenApplied( const ChartHypothesis &hypo, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const diff --git a/moses/FF/ConstrainedDecoding.h b/moses/FF/ConstrainedDecoding.h index 2db192ce8..ca007f21d 100644 --- a/moses/FF/ConstrainedDecoding.h +++ b/moses/FF/ConstrainedDecoding.h @@ -41,13 +41,13 @@ public: return true; } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {} - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -55,12 +55,12 @@ public: , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - FFState* Evaluate( + FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - FFState* EvaluateChart( + FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const; diff --git a/moses/FF/ControlRecombination.cpp b/moses/FF/ControlRecombination.cpp index d3e7c82ab..85e88ac94 100644 --- a/moses/FF/ControlRecombination.cpp +++ b/moses/FF/ControlRecombination.cpp @@ -56,7 +56,7 @@ std::vector ControlRecombination::DefaultWeights() const return ret; } -FFState* ControlRecombination::Evaluate( +FFState* ControlRecombination::EvaluateWhenApplied( const 
Hypothesis& hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const @@ -64,7 +64,7 @@ FFState* ControlRecombination::Evaluate( return new ControlRecombinationState(hypo, *this); } -FFState* ControlRecombination::EvaluateChart( +FFState* ControlRecombination::EvaluateWhenApplied( const ChartHypothesis &hypo, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const diff --git a/moses/FF/ControlRecombination.h b/moses/FF/ControlRecombination.h index 0100d500d..095cc6b29 100644 --- a/moses/FF/ControlRecombination.h +++ b/moses/FF/ControlRecombination.h @@ -57,24 +57,24 @@ public: return true; } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {} - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - FFState* Evaluate( + FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - FFState* EvaluateChart( + FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const; diff --git a/moses/FF/CountNonTerms.cpp b/moses/FF/CountNonTerms.cpp index 92b79cd5d..03c7b7315 100644 --- a/moses/FF/CountNonTerms.cpp +++ b/moses/FF/CountNonTerms.cpp @@ -16,7 +16,7 @@ CountNonTerms::CountNonTerms(const std::string &line) ReadParameters(); } -void CountNonTerms::Evaluate(const Phrase &sourcePhrase +void CountNonTerms::EvaluateInIsolation(const Phrase &sourcePhrase , const TargetPhrase &targetPhrase , 
ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/CountNonTerms.h b/moses/FF/CountNonTerms.h index 1fe71745d..c4e1467e9 100644 --- a/moses/FF/CountNonTerms.h +++ b/moses/FF/CountNonTerms.h @@ -12,12 +12,12 @@ public: bool IsUseable(const FactorMask &mask) const { return true; } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -25,11 +25,11 @@ public: , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart( + void EvaluateWhenApplied( const ChartHypothesis& hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/CoveredReferenceFeature.cpp b/moses/FF/CoveredReferenceFeature.cpp index 25ab829f8..3a2482d0d 100644 --- a/moses/FF/CoveredReferenceFeature.cpp +++ b/moses/FF/CoveredReferenceFeature.cpp @@ -40,13 +40,13 @@ int CoveredReferenceState::Compare(const FFState& other) const // return (m_coveredRef.size() < otherState.m_coveredRef.size()) ? 
-1 : +1; } -void CoveredReferenceFeature::Evaluate(const Phrase &source +void CoveredReferenceFeature::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {} -void CoveredReferenceFeature::Evaluate(const InputType &input +void CoveredReferenceFeature::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -90,7 +90,7 @@ void CoveredReferenceFeature::SetParameter(const std::string& key, const std::st } } -FFState* CoveredReferenceFeature::Evaluate( +FFState* CoveredReferenceFeature::EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const @@ -131,7 +131,7 @@ FFState* CoveredReferenceFeature::Evaluate( return ret; } -FFState* CoveredReferenceFeature::EvaluateChart( +FFState* CoveredReferenceFeature::EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const diff --git a/moses/FF/CoveredReferenceFeature.h b/moses/FF/CoveredReferenceFeature.h index cd2b2f966..a6cdd6f99 100644 --- a/moses/FF/CoveredReferenceFeature.h +++ b/moses/FF/CoveredReferenceFeature.h @@ -52,21 +52,21 @@ public: return new CoveredReferenceState(); } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const; - FFState* Evaluate( + FFState* EvaluateWhenApplied( const 
Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - FFState* EvaluateChart( + FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const; diff --git a/moses/FF/DecodeFeature.h b/moses/FF/DecodeFeature.h index d79598328..ac4e9392b 100644 --- a/moses/FF/DecodeFeature.h +++ b/moses/FF/DecodeFeature.h @@ -62,20 +62,20 @@ public: bool IsUseable(const FactorMask &mask) const; void SetParameter(const std::string& key, const std::string& value); - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/DistortionScoreProducer.cpp b/moses/FF/DistortionScoreProducer.cpp index 303f35236..5995fe213 100644 --- a/moses/FF/DistortionScoreProducer.cpp +++ b/moses/FF/DistortionScoreProducer.cpp @@ -87,7 +87,7 @@ float DistortionScoreProducer::CalculateDistortionScore(const Hypothesis& hypo, } -FFState* DistortionScoreProducer::Evaluate( +FFState* DistortionScoreProducer::EvaluateWhenApplied( const Hypothesis& hypo, const FFState* prev_state, ScoreComponentCollection* out) const diff --git a/moses/FF/DistortionScoreProducer.h 
b/moses/FF/DistortionScoreProducer.h index 1bc6493e2..aa2c18b95 100644 --- a/moses/FF/DistortionScoreProducer.h +++ b/moses/FF/DistortionScoreProducer.h @@ -28,26 +28,26 @@ public: virtual const FFState* EmptyHypothesisState(const InputType &input) const; - virtual FFState* Evaluate( + virtual FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - virtual FFState* EvaluateChart( + virtual FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection*) const { throw std::logic_error("DistortionScoreProducer not supported in chart decoder, yet"); } - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/ExternalFeature.cpp b/moses/FF/ExternalFeature.cpp index 141541170..10800d24d 100644 --- a/moses/FF/ExternalFeature.cpp +++ b/moses/FF/ExternalFeature.cpp @@ -51,7 +51,7 @@ void ExternalFeature::SetParameter(const std::string& key, const std::string& va } } -FFState* ExternalFeature::Evaluate( +FFState* ExternalFeature::EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const @@ -59,7 +59,7 @@ FFState* ExternalFeature::Evaluate( return new ExternalFeatureState(m_stateSize); } -FFState* ExternalFeature::EvaluateChart( +FFState* ExternalFeature::EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the 
previous hypotheses */, ScoreComponentCollection* accumulator) const diff --git a/moses/FF/ExternalFeature.h b/moses/FF/ExternalFeature.h index 19eb45f2a..a8916a853 100644 --- a/moses/FF/ExternalFeature.h +++ b/moses/FF/ExternalFeature.h @@ -51,24 +51,24 @@ public: void SetParameter(const std::string& key, const std::string& value); - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {} - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - FFState* Evaluate( + FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - FFState* EvaluateChart( + FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const; diff --git a/moses/FF/Factory.cpp b/moses/FF/Factory.cpp index 347e34ad9..7eb53a40a 100644 --- a/moses/FF/Factory.cpp +++ b/moses/FF/Factory.cpp @@ -10,6 +10,7 @@ #include "moses/TranslationModel/PhraseDictionaryDynSuffixArray.h" #include "moses/TranslationModel/PhraseDictionaryScope3.h" #include "moses/TranslationModel/PhraseDictionaryTransliteration.h" +#include "moses/TranslationModel/RuleTable/PhraseDictionaryFuzzyMatch.h" #include "moses/FF/LexicalReordering/LexicalReordering.h" @@ -26,6 +27,7 @@ #include "moses/FF/PhrasePairFeature.h" #include "moses/FF/PhraseLengthFeature.h" #include "moses/FF/DistortionScoreProducer.h" +#include "moses/FF/SparseHieroReorderingFeature.h" #include "moses/FF/WordPenaltyProducer.h" #include "moses/FF/InputFeature.h" #include 
"moses/FF/PhrasePenalty.h" @@ -36,6 +38,7 @@ #include "moses/FF/CoveredReferenceFeature.h" #include "moses/FF/TreeStructureFeature.h" #include "moses/FF/SoftMatchingFeature.h" +#include "moses/FF/SourceGHKMTreeInputMatchFeature.h" #include "moses/FF/HyperParameterAsWeight.h" #include "moses/FF/SetSourcePhrase.h" #include "CountNonTerms.h" @@ -43,6 +46,8 @@ #include "RuleScope.h" #include "MaxSpanFreeNonTermSource.h" #include "NieceTerminal.h" +#include "SpanLength.h" +#include "SyntaxRHS.h" #include "moses/FF/SkeletonStatelessFF.h" #include "moses/FF/SkeletonStatefulFF.h" @@ -55,6 +60,9 @@ #ifdef PT_UG #include "moses/TranslationModel/UG/mmsapt.h" #endif +#ifdef HAVE_PROBINGPT +#include "moses/TranslationModel/ProbingPT/ProbingPT.h" +#endif #include "moses/LM/Ken.h" #ifdef LM_IRST @@ -85,6 +93,10 @@ #include "moses/LM/DALMWrapper.h" #endif +#ifdef LM_LBL +#include "moses/LM/oxlm/LBLLM.h" +#endif + #include "util/exception.hh" #include @@ -148,6 +160,18 @@ FeatureRegistry::FeatureRegistry() #define MOSES_FNAME(name) Add(#name, new DefaultFeatureFactory< name >()); // Feature with different name than class. 
#define MOSES_FNAME2(name, type) Add(name, new DefaultFeatureFactory< type >()); + + MOSES_FNAME2("PhraseDictionaryBinary", PhraseDictionaryTreeAdaptor); + MOSES_FNAME(PhraseDictionaryOnDisk); + MOSES_FNAME(PhraseDictionaryMemory); + MOSES_FNAME(PhraseDictionaryScope3); + MOSES_FNAME(PhraseDictionaryMultiModel); + MOSES_FNAME(PhraseDictionaryMultiModelCounts); + MOSES_FNAME(PhraseDictionaryALSuffixArray); + MOSES_FNAME(PhraseDictionaryDynSuffixArray); + MOSES_FNAME(PhraseDictionaryTransliteration); + MOSES_FNAME(PhraseDictionaryFuzzyMatch); + MOSES_FNAME(GlobalLexicalModel); //MOSES_FNAME(GlobalLexicalModelUnlimited); This was commented out in the original MOSES_FNAME(SourceWordDeletionFeature); @@ -164,15 +188,6 @@ FeatureRegistry::FeatureRegistry() MOSES_FNAME2("Distortion", DistortionScoreProducer); MOSES_FNAME2("WordPenalty", WordPenaltyProducer); MOSES_FNAME(InputFeature); - MOSES_FNAME2("PhraseDictionaryBinary", PhraseDictionaryTreeAdaptor); - MOSES_FNAME(PhraseDictionaryOnDisk); - MOSES_FNAME(PhraseDictionaryMemory); - MOSES_FNAME(PhraseDictionaryScope3); - MOSES_FNAME(PhraseDictionaryMultiModel); - MOSES_FNAME(PhraseDictionaryMultiModelCounts); - MOSES_FNAME(PhraseDictionaryALSuffixArray); - MOSES_FNAME(PhraseDictionaryDynSuffixArray); - MOSES_FNAME(PhraseDictionaryTransliteration); MOSES_FNAME(OpSequenceModel); MOSES_FNAME(PhrasePenalty); MOSES_FNAME2("UnknownWordPenalty", UnknownWordPenaltyProducer); @@ -180,6 +195,7 @@ FeatureRegistry::FeatureRegistry() MOSES_FNAME(ConstrainedDecoding); MOSES_FNAME(CoveredReferenceFeature); MOSES_FNAME(ExternalFeature); + MOSES_FNAME(SourceGHKMTreeInputMatchFeature); MOSES_FNAME(TreeStructureFeature); MOSES_FNAME(SoftMatchingFeature); MOSES_FNAME(HyperParameterAsWeight); @@ -189,6 +205,9 @@ FeatureRegistry::FeatureRegistry() MOSES_FNAME(RuleScope); MOSES_FNAME(MaxSpanFreeNonTermSource); MOSES_FNAME(NieceTerminal); + MOSES_FNAME(SparseHieroReorderingFeature); + MOSES_FNAME(SpanLength); + MOSES_FNAME(SyntaxRHS); 
MOSES_FNAME(SkeletonStatelessFF); MOSES_FNAME(SkeletonStatefulFF); @@ -201,6 +220,10 @@ FeatureRegistry::FeatureRegistry() #ifdef PT_UG MOSES_FNAME(Mmsapt); #endif +#ifdef HAVE_PROBINGPT + MOSES_FNAME(ProbingPT); +#endif + #ifdef HAVE_SYNLM MOSES_FNAME(SyntacticLanguageModel); #endif @@ -222,6 +245,11 @@ FeatureRegistry::FeatureRegistry() #ifdef LM_DALM MOSES_FNAME2("DALM", LanguageModelDALM); #endif +#ifdef LM_LBL + MOSES_FNAME2("LBLLM-LM", LBLLM); + MOSES_FNAME2("LBLLM-FactoredLM", LBLLM); + MOSES_FNAME2("LBLLM-FactoredMaxentLM", LBLLM); +#endif Add("KENLM", new KenFactory()); } @@ -250,12 +278,21 @@ void FeatureRegistry::Construct(const std::string &name, const std::string &line void FeatureRegistry::PrintFF() const { + vector ffs; std::cerr << "Available feature functions:" << std::endl; Map::const_iterator iter; for (iter = registry_.begin(); iter != registry_.end(); ++iter) { const string &ffName = iter->first; + ffs.push_back(ffName); + } + + vector::const_iterator iterVec; + std::sort(ffs.begin(), ffs.end()); + for (iterVec = ffs.begin(); iterVec != ffs.end(); ++iterVec) { + const string &ffName = *iterVec; std::cerr << ffName << " "; } + std::cerr << std::endl; } diff --git a/moses/FF/FeatureFunction.h b/moses/FF/FeatureFunction.h index 18b016c8f..42ac12974 100644 --- a/moses/FF/FeatureFunction.h +++ b/moses/FF/FeatureFunction.h @@ -98,7 +98,7 @@ public: // source phrase is the substring that the phrase table uses to look up the target phrase, // may have more factors than actually need, but not guaranteed. // For SCFG decoding, the source contains non-terminals, NOT the raw source from the input sentence - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const = 0; @@ -110,7 +110,7 @@ public: // It is guaranteed to be in the same order as the non-terms in the source phrase. 
// For pb models, stackvec is NULL. // No FF should set estimatedFutureScore in both overloads! - virtual void Evaluate(const InputType &input + virtual void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec diff --git a/moses/FF/GlobalLexicalModel.cpp b/moses/FF/GlobalLexicalModel.cpp index ff9e87bb0..f6eb165a8 100644 --- a/moses/FF/GlobalLexicalModel.cpp +++ b/moses/FF/GlobalLexicalModel.cpp @@ -165,7 +165,7 @@ float GlobalLexicalModel::GetFromCacheOrScorePhrase( const TargetPhrase& targetP return score; } -void GlobalLexicalModel::Evaluate +void GlobalLexicalModel::EvaluateWhenApplied (const Hypothesis& hypo, ScoreComponentCollection* accumulator) const { diff --git a/moses/FF/GlobalLexicalModel.h b/moses/FF/GlobalLexicalModel.h index 664835df5..151dbf472 100644 --- a/moses/FF/GlobalLexicalModel.h +++ b/moses/FF/GlobalLexicalModel.h @@ -70,24 +70,24 @@ public: bool IsUseable(const FactorMask &mask) const; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const; - void EvaluateChart( + void EvaluateWhenApplied( const ChartHypothesis& hypo, ScoreComponentCollection* accumulator) const { throw std::logic_error("GlobalLexicalModel not supported in chart decoder, yet"); } - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/GlobalLexicalModelUnlimited.cpp b/moses/FF/GlobalLexicalModelUnlimited.cpp index a6883a7e8..c8dbd5883 
100644 --- a/moses/FF/GlobalLexicalModelUnlimited.cpp +++ b/moses/FF/GlobalLexicalModelUnlimited.cpp @@ -108,7 +108,7 @@ void GlobalLexicalModelUnlimited::InitializeForInput( Sentence const& in ) m_local->input = ∈ } -void GlobalLexicalModelUnlimited::Evaluate(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const +void GlobalLexicalModelUnlimited::EvaluateWhenApplied(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const { const Sentence& input = *(m_local->input); const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase(); diff --git a/moses/FF/GlobalLexicalModelUnlimited.h b/moses/FF/GlobalLexicalModelUnlimited.h index f12df7d61..096254613 100644 --- a/moses/FF/GlobalLexicalModelUnlimited.h +++ b/moses/FF/GlobalLexicalModelUnlimited.h @@ -81,23 +81,23 @@ public: //TODO: This implements the old interface, but cannot be updated because //it appears to be stateful - void Evaluate(const Hypothesis& cur_hypo, + void EvaluateWhenApplied(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const; - void EvaluateChart(const ChartHypothesis& /* cur_hypo */, + void EvaluateWhenApplied(const ChartHypothesis& /* cur_hypo */, int /* featureID */, ScoreComponentCollection* ) const { throw std::logic_error("GlobalLexicalModelUnlimited not supported in chart decoder, yet"); } - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/HyperParameterAsWeight.h b/moses/FF/HyperParameterAsWeight.h index 9db375c0f..aaad21c14 100644 --- 
a/moses/FF/HyperParameterAsWeight.h +++ b/moses/FF/HyperParameterAsWeight.h @@ -17,13 +17,13 @@ public: virtual bool IsUseable(const FactorMask &mask) const { return true; } - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {} - virtual void Evaluate(const InputType &input + virtual void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -31,14 +31,14 @@ public: , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - virtual void Evaluate(const Hypothesis& hypo, + virtual void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} /** * Same for chart-based features. **/ - virtual void EvaluateChart(const ChartHypothesis &hypo, + virtual void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/InputFeature.cpp b/moses/FF/InputFeature.cpp index 0fa2005d1..61753c595 100644 --- a/moses/FF/InputFeature.cpp +++ b/moses/FF/InputFeature.cpp @@ -44,7 +44,7 @@ void InputFeature::SetParameter(const std::string& key, const std::string& value } -void InputFeature::Evaluate(const InputType &input +void InputFeature::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec diff --git a/moses/FF/InputFeature.h b/moses/FF/InputFeature.h index e4b1a8d99..ad4fe398a 100644 --- a/moses/FF/InputFeature.h +++ b/moses/FF/InputFeature.h @@ -41,22 +41,23 @@ public: return m_numRealWordCount; } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {} - void 
Evaluate(const InputType &input + + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/InternalStructStatelessFF.cpp b/moses/FF/InternalStructStatelessFF.cpp index 06014a1cf..a050bd8ef 100644 --- a/moses/FF/InternalStructStatelessFF.cpp +++ b/moses/FF/InternalStructStatelessFF.cpp @@ -5,7 +5,7 @@ using namespace std; namespace Moses { -void InternalStructStatelessFF::Evaluate(const Phrase &source +void InternalStructStatelessFF::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const @@ -15,7 +15,7 @@ void InternalStructStatelessFF::Evaluate(const Phrase &source } -void InternalStructStatelessFF::Evaluate(const InputType &input +void InternalStructStatelessFF::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec diff --git a/moses/FF/InternalStructStatelessFF.h b/moses/FF/InternalStructStatelessFF.h index a0ea3f712..2ed8801e2 100644 --- a/moses/FF/InternalStructStatelessFF.h +++ b/moses/FF/InternalStructStatelessFF.h @@ -16,21 +16,21 @@ public: bool IsUseable(const FactorMask &mask) const { return true; } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const InputType 
&input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const; - virtual void Evaluate(const Hypothesis& hypo, + virtual void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/LexicalReordering/LexicalReordering.cpp b/moses/FF/LexicalReordering/LexicalReordering.cpp index 6a2a488d9..426a7d91c 100644 --- a/moses/FF/LexicalReordering/LexicalReordering.cpp +++ b/moses/FF/LexicalReordering/LexicalReordering.cpp @@ -14,11 +14,12 @@ LexicalReordering::LexicalReordering(const std::string &line) { std::cerr << "Initializing LexicalReordering.." << std::endl; + map sparseArgs; for (size_t i = 0; i < m_args.size(); ++i) { const vector &args = m_args[i]; if (args[0] == "type") { - m_configuration = new LexicalReorderingConfiguration(args[1]); + m_configuration.reset(new LexicalReorderingConfiguration(args[1])); m_configuration->SetScoreProducer(this); m_modelTypeString = m_configuration->GetModelString(); } else if (args[0] == "input-factor") { @@ -27,8 +28,10 @@ LexicalReordering::LexicalReordering(const std::string &line) m_factorsE =Tokenize(args[1]); } else if (args[0] == "path") { m_filePath = args[1]; + } else if (args[0].substr(0,7) == "sparse-") { + sparseArgs[args[0].substr(7)] = args[1]; } else { - throw "Unknown argument " + args[0]; + UTIL_THROW(util::Exception,"Unknown argument " + args[0]); } } @@ -36,29 +39,29 @@ LexicalReordering::LexicalReordering(const std::string &line) case LexicalReorderingConfiguration::FE: case LexicalReorderingConfiguration::E: if(m_factorsE.empty()) { - throw "TL factor mask for lexical reordering is unexpectedly empty"; + 
UTIL_THROW(util::Exception,"TL factor mask for lexical reordering is unexpectedly empty"); } if(m_configuration->GetCondition() == LexicalReorderingConfiguration::E) break; // else fall through case LexicalReorderingConfiguration::F: if(m_factorsF.empty()) { - throw "SL factor mask for lexical reordering is unexpectedly empty"; + UTIL_THROW(util::Exception,"SL factor mask for lexical reordering is unexpectedly empty"); } break; default: - throw "Unknown conditioning option!"; + UTIL_THROW(util::Exception,"Unknown conditioning option!"); } + + m_configuration->ConfigureSparse(sparseArgs, this); } LexicalReordering::~LexicalReordering() { - delete m_table; - delete m_configuration; } void LexicalReordering::Load() { - m_table = LexicalReorderingTable::LoadAvailable(m_filePath, m_factorsF, m_factorsE, std::vector()); + m_table.reset(LexicalReorderingTable::LoadAvailable(m_filePath, m_factorsF, m_factorsE, std::vector())); } Scores LexicalReordering::GetProb(const Phrase& f, const Phrase& e) const @@ -66,13 +69,13 @@ Scores LexicalReordering::GetProb(const Phrase& f, const Phrase& e) const return m_table->GetScore(f, e, Phrase(ARRAY_SIZE_INCR)); } -FFState* LexicalReordering::Evaluate(const Hypothesis& hypo, +FFState* LexicalReordering::EvaluateWhenApplied(const Hypothesis& hypo, const FFState* prev_state, ScoreComponentCollection* out) const { Scores score(GetNumScoreComponents(), 0); const LexicalReorderingState *prev = dynamic_cast(prev_state); - LexicalReorderingState *next_state = prev->Expand(hypo.GetTranslationOption(), score); + LexicalReorderingState *next_state = prev->Expand(hypo.GetTranslationOption(), hypo.GetInput(), out); out->PlusEquals(this, score); diff --git a/moses/FF/LexicalReordering/LexicalReordering.h b/moses/FF/LexicalReordering/LexicalReordering.h index 4ff0057f0..09d3b73cc 100644 --- a/moses/FF/LexicalReordering/LexicalReordering.h +++ b/moses/FF/LexicalReordering/LexicalReordering.h @@ -3,17 +3,20 @@ #include #include +#include #include 
"moses/Factor.h" #include "moses/Phrase.h" #include "moses/TypeDef.h" #include "moses/Util.h" #include "moses/WordsRange.h" -#include "LexicalReorderingState.h" -#include "LexicalReorderingTable.h" #include "moses/FF/StatefulFeatureFunction.h" #include "util/exception.hh" +#include "LexicalReorderingState.h" +#include "LexicalReorderingTable.h" +#include "SparseReordering.h" + namespace Moses { @@ -42,23 +45,23 @@ public: Scores GetProb(const Phrase& f, const Phrase& e) const; - virtual FFState* Evaluate(const Hypothesis& cur_hypo, + virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - virtual FFState* EvaluateChart(const ChartHypothesis&, + virtual FFState* EvaluateWhenApplied(const ChartHypothesis&, int /* featureID */, ScoreComponentCollection*) const { UTIL_THROW(util::Exception, "LexicalReordering is not valid for chart decoder"); } - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const @@ -69,10 +72,10 @@ private: bool DecodeDirection(std::string s); bool DecodeNumFeatureFunctions(std::string s); - LexicalReorderingConfiguration *m_configuration; + boost::scoped_ptr m_configuration; std::string m_modelTypeString; std::vector m_modelType; - LexicalReorderingTable* m_table; + boost::scoped_ptr m_table; //std::vector m_direction; std::vector m_condition; //std::vector m_scoreOffset; diff --git a/moses/FF/LexicalReordering/LexicalReorderingState.cpp b/moses/FF/LexicalReordering/LexicalReorderingState.cpp index 
aa29a4a12..fa88fdeab 100644 --- a/moses/FF/LexicalReordering/LexicalReorderingState.cpp +++ b/moses/FF/LexicalReordering/LexicalReorderingState.cpp @@ -5,11 +5,11 @@ #include "moses/FF/FFState.h" #include "moses/Hypothesis.h" #include "moses/WordsRange.h" -#include "moses/ReorderingStack.h" #include "moses/TranslationOption.h" #include "LexicalReordering.h" #include "LexicalReorderingState.h" +#include "ReorderingStack.h" namespace Moses { @@ -38,6 +38,14 @@ size_t LexicalReorderingConfiguration::GetNumScoreComponents() const } } +void LexicalReorderingConfiguration::ConfigureSparse + (const std::map& sparseArgs, const LexicalReordering* producer) +{ + if (sparseArgs.size()) { + m_sparse.reset(new SparseReordering(sparseArgs, producer)); + } +} + void LexicalReorderingConfiguration::SetAdditionalScoreComponents(size_t number) { m_additionalScoreComponents = number; @@ -122,52 +130,52 @@ LexicalReorderingState *LexicalReorderingConfiguration::CreateLexicalReorderingS return new BidirectionalReorderingState(*this, bwd, fwd, 0); } -void LexicalReorderingState::CopyScores(Scores& scores, const TranslationOption &topt, ReorderingType reoType) const +void LexicalReorderingState::CopyScores(ScoreComponentCollection* accum, const TranslationOption &topt, const InputType& input, ReorderingType reoType) const { // don't call this on a bidirectional object UTIL_THROW_IF2(m_direction != LexicalReorderingConfiguration::Backward && m_direction != LexicalReorderingConfiguration::Forward, "Unknown direction: " << m_direction); - const Scores *cachedScores = (m_direction == LexicalReorderingConfiguration::Backward) ? - topt.GetLexReorderingScores(m_configuration.GetScoreProducer()) : m_prevScore; + const TranslationOption* relevantOpt = &topt; + if (m_direction != LexicalReorderingConfiguration::Backward) relevantOpt = m_prevOption; + const Scores *cachedScores = relevantOpt->GetLexReorderingScores(m_configuration.GetScoreProducer()); - // No scores available. 
TODO: Using a good prior distribution would be nicer. - if(cachedScores == NULL) - return; + if(cachedScores) { + Scores scores(m_configuration.GetScoreProducer()->GetNumScoreComponents(),0); - const Scores &scoreSet = *cachedScores; - if(m_configuration.CollapseScores()) - scores[m_offset] = scoreSet[m_offset + reoType]; - else { - std::fill(scores.begin() + m_offset, scores.begin() + m_offset + m_configuration.GetNumberOfTypes(), 0); - scores[m_offset + reoType] = scoreSet[m_offset + reoType]; + const Scores &scoreSet = *cachedScores; + if(m_configuration.CollapseScores()) + scores[m_offset] = scoreSet[m_offset + reoType]; + else { + std::fill(scores.begin() + m_offset, scores.begin() + m_offset + m_configuration.GetNumberOfTypes(), 0); + scores[m_offset + reoType] = scoreSet[m_offset + reoType]; + } + accum->PlusEquals(m_configuration.GetScoreProducer(), scores); } + + const SparseReordering* sparse = m_configuration.GetSparseReordering(); + if (sparse) sparse->CopyScores(*relevantOpt, m_prevOption, input, reoType, m_direction, accum); + } -void LexicalReorderingState::ClearScores(Scores& scores) const -{ - if(m_configuration.CollapseScores()) - scores[m_offset] = 0; - else - std::fill(scores.begin() + m_offset, scores.begin() + m_offset + m_configuration.GetNumberOfTypes(), 0); -} -int LexicalReorderingState::ComparePrevScores(const Scores *other) const +int LexicalReorderingState::ComparePrevScores(const TranslationOption *other) const { - if(m_prevScore == other) + const Scores* myPrevScores = m_prevOption->GetLexReorderingScores(m_configuration.GetScoreProducer()); + const Scores* otherPrevScores = other->GetLexReorderingScores(m_configuration.GetScoreProducer()); + + if(myPrevScores == otherPrevScores) return 0; // The pointers are NULL if a phrase pair isn't found in the reordering table. 
- if(other == NULL) + if(otherPrevScores == NULL) return -1; - if(m_prevScore == NULL) + if(myPrevScores == NULL) return 1; - const Scores &my = *m_prevScore; - const Scores &their = *other; for(size_t i = m_offset; i < m_offset + m_configuration.GetNumberOfTypes(); i++) - if(my[i] < their[i]) + if((*myPrevScores)[i] < (*otherPrevScores)[i]) return -1; - else if(my[i] > their[i]) + else if((*myPrevScores)[i] > (*otherPrevScores)[i]) return 1; return 0; @@ -189,11 +197,10 @@ int PhraseBasedReorderingState::Compare(const FFState& o) const if (&o == this) return 0; - const PhraseBasedReorderingState* other = dynamic_cast(&o); - UTIL_THROW_IF2(other == NULL, "Wrong state type"); + const PhraseBasedReorderingState* other = static_cast(&o); if (m_prevRange == other->m_prevRange) { if (m_direction == LexicalReorderingConfiguration::Forward) { - return ComparePrevScores(other->m_prevScore); + return ComparePrevScores(other->m_prevOption); } else { return 0; } @@ -203,27 +210,23 @@ int PhraseBasedReorderingState::Compare(const FFState& o) const return 1; } -LexicalReorderingState* PhraseBasedReorderingState::Expand(const TranslationOption& topt, Scores& scores) const +LexicalReorderingState* PhraseBasedReorderingState::Expand(const TranslationOption& topt, const InputType& input,ScoreComponentCollection* scores) const { ReorderingType reoType; const WordsRange currWordsRange = topt.GetSourceWordsRange(); const LexicalReorderingConfiguration::ModelType modelType = m_configuration.GetModelType(); - if (m_direction == LexicalReorderingConfiguration::Forward && m_first) { - ClearScores(scores); - } else { - if (!m_first || m_useFirstBackwardScore) { - if (modelType == LexicalReorderingConfiguration::MSD) { - reoType = GetOrientationTypeMSD(currWordsRange); - } else if (modelType == LexicalReorderingConfiguration::MSLR) { - reoType = GetOrientationTypeMSLR(currWordsRange); - } else if (modelType == LexicalReorderingConfiguration::Monotonic) { - reoType = 
GetOrientationTypeMonotonic(currWordsRange); - } else { - reoType = GetOrientationTypeLeftRight(currWordsRange); - } - CopyScores(scores, topt, reoType); + if ((m_direction != LexicalReorderingConfiguration::Forward && m_useFirstBackwardScore) || !m_first) { + if (modelType == LexicalReorderingConfiguration::MSD) { + reoType = GetOrientationTypeMSD(currWordsRange); + } else if (modelType == LexicalReorderingConfiguration::MSLR) { + reoType = GetOrientationTypeMSLR(currWordsRange); + } else if (modelType == LexicalReorderingConfiguration::Monotonic) { + reoType = GetOrientationTypeMonotonic(currWordsRange); + } else { + reoType = GetOrientationTypeLeftRight(currWordsRange); } + CopyScores(scores, topt, input, reoType); } return new PhraseBasedReorderingState(this, topt); @@ -292,7 +295,7 @@ int BidirectionalReorderingState::Compare(const FFState& o) const if (&o == this) return 0; - const BidirectionalReorderingState &other = dynamic_cast(o); + const BidirectionalReorderingState &other = static_cast(o); if(m_backward->Compare(*other.m_backward) < 0) return -1; else if(m_backward->Compare(*other.m_backward) > 0) @@ -301,10 +304,10 @@ int BidirectionalReorderingState::Compare(const FFState& o) const return m_forward->Compare(*other.m_forward); } -LexicalReorderingState* BidirectionalReorderingState::Expand(const TranslationOption& topt, Scores& scores) const +LexicalReorderingState* BidirectionalReorderingState::Expand(const TranslationOption& topt, const InputType& input, ScoreComponentCollection* scores) const { - LexicalReorderingState *newbwd = m_backward->Expand(topt, scores); - LexicalReorderingState *newfwd = m_forward->Expand(topt, scores); + LexicalReorderingState *newbwd = m_backward->Expand(topt,input, scores); + LexicalReorderingState *newfwd = m_forward->Expand(topt, input, scores); return new BidirectionalReorderingState(m_configuration, newbwd, newfwd, m_offset); } @@ -321,11 +324,11 @@ 
HierarchicalReorderingBackwardState::HierarchicalReorderingBackwardState(const L int HierarchicalReorderingBackwardState::Compare(const FFState& o) const { - const HierarchicalReorderingBackwardState& other = dynamic_cast(o); + const HierarchicalReorderingBackwardState& other = static_cast(o); return m_reoStack.Compare(other.m_reoStack); } -LexicalReorderingState* HierarchicalReorderingBackwardState::Expand(const TranslationOption& topt, Scores& scores) const +LexicalReorderingState* HierarchicalReorderingBackwardState::Expand(const TranslationOption& topt, const InputType& input,ScoreComponentCollection* scores) const { HierarchicalReorderingBackwardState* nextState = new HierarchicalReorderingBackwardState(this, topt, m_reoStack); @@ -344,7 +347,7 @@ LexicalReorderingState* HierarchicalReorderingBackwardState::Expand(const Transl reoType = GetOrientationTypeMonotonic(reoDistance); } - CopyScores(scores, topt, reoType); + CopyScores(scores, topt, input, reoType); return nextState; } @@ -407,11 +410,10 @@ int HierarchicalReorderingForwardState::Compare(const FFState& o) const if (&o == this) return 0; - const HierarchicalReorderingForwardState* other = dynamic_cast(&o); - UTIL_THROW_IF2(other == NULL, "Wrong state type"); + const HierarchicalReorderingForwardState* other = static_cast(&o); if (m_prevRange == other->m_prevRange) { - return ComparePrevScores(other->m_prevScore); + return ComparePrevScores(other->m_prevOption); } else if (m_prevRange < other->m_prevRange) { return -1; } @@ -429,7 +431,7 @@ int HierarchicalReorderingForwardState::Compare(const FFState& o) const // dright: if the next phrase follows the conditioning phrase and other stuff comes in between // dleft: if the next phrase precedes the conditioning phrase and other stuff comes in between -LexicalReorderingState* HierarchicalReorderingForwardState::Expand(const TranslationOption& topt, Scores& scores) const +LexicalReorderingState* HierarchicalReorderingForwardState::Expand(const 
TranslationOption& topt, const InputType& input,ScoreComponentCollection* scores) const { const LexicalReorderingConfiguration::ModelType modelType = m_configuration.GetModelType(); const WordsRange currWordsRange = topt.GetSourceWordsRange(); @@ -440,7 +442,7 @@ LexicalReorderingState* HierarchicalReorderingForwardState::Expand(const Transla ReorderingType reoType; if (m_first) { - ClearScores(scores); + } else { if (modelType == LexicalReorderingConfiguration::MSD) { reoType = GetOrientationTypeMSD(currWordsRange, coverage); @@ -452,7 +454,7 @@ LexicalReorderingState* HierarchicalReorderingForwardState::Expand(const Transla reoType = GetOrientationTypeLeftRight(currWordsRange, coverage); } - CopyScores(scores, topt, reoType); + CopyScores(scores, topt, input, reoType); } return new HierarchicalReorderingForwardState(this, topt); diff --git a/moses/FF/LexicalReordering/LexicalReorderingState.h b/moses/FF/LexicalReordering/LexicalReorderingState.h index 8e237adc1..e309ed7f1 100644 --- a/moses/FF/LexicalReordering/LexicalReorderingState.h +++ b/moses/FF/LexicalReordering/LexicalReorderingState.h @@ -4,22 +4,25 @@ #include #include +#include + #include "moses/Hypothesis.h" -#include "LexicalReordering.h" +//#include "LexicalReordering.h" +#include "moses/ScoreComponentCollection.h" #include "moses/WordsRange.h" #include "moses/WordsBitmap.h" -#include "moses/ReorderingStack.h" #include "moses/TranslationOption.h" #include "moses/FF/FFState.h" +#include "ReorderingStack.h" namespace Moses { class LexicalReorderingState; class LexicalReordering; +class SparseReordering; /** Factory class for lexical reordering states - * @todo There's a lot of classes for lexicalized reordering. 
Perhaps put them in a separate dir */ class LexicalReorderingConfiguration { @@ -31,6 +34,8 @@ public: LexicalReorderingConfiguration(const std::string &modelType); + void ConfigureSparse(const std::map& sparseArgs, const LexicalReordering* producer); + LexicalReorderingState *CreateLexicalReorderingState(const InputType &input) const; size_t GetNumScoreComponents() const; @@ -62,6 +67,10 @@ public: return m_collapseScores; } + const SparseReordering* GetSparseReordering() const { + return m_sparse.get(); + } + private: void SetScoreProducer(LexicalReordering* scoreProducer) { m_scoreProducer = scoreProducer; @@ -79,6 +88,7 @@ private: Direction m_direction; Condition m_condition; size_t m_additionalScoreComponents; + boost::scoped_ptr m_sparse; }; //! Abstract class for lexical reordering model states @@ -86,34 +96,35 @@ class LexicalReorderingState : public FFState { public: virtual int Compare(const FFState& o) const = 0; - virtual LexicalReorderingState* Expand(const TranslationOption& hypo, Scores& scores) const = 0; + virtual LexicalReorderingState* Expand(const TranslationOption& hypo, const InputType& input, ScoreComponentCollection* scores) const = 0; static LexicalReorderingState* CreateLexicalReorderingState(const std::vector& config, LexicalReorderingConfiguration::Direction dir, const InputType &input); + typedef int ReorderingType; protected: - typedef int ReorderingType; const LexicalReorderingConfiguration &m_configuration; // The following is the true direction of the object, which can be Backward or Forward even if the Configuration has Bidirectional. 
LexicalReorderingConfiguration::Direction m_direction; size_t m_offset; - const Scores *m_prevScore; + //forward scores are conditioned on prev option, so need to remember it + const TranslationOption *m_prevOption; inline LexicalReorderingState(const LexicalReorderingState *prev, const TranslationOption &topt) : m_configuration(prev->m_configuration), m_direction(prev->m_direction), m_offset(prev->m_offset), - m_prevScore(topt.GetLexReorderingScores(m_configuration.GetScoreProducer())) {} + m_prevOption(&topt) {} inline LexicalReorderingState(const LexicalReorderingConfiguration &config, LexicalReorderingConfiguration::Direction dir, size_t offset) - : m_configuration(config), m_direction(dir), m_offset(offset), m_prevScore(NULL) {} + : m_configuration(config), m_direction(dir), m_offset(offset), m_prevOption(NULL) {} // copy the right scores in the right places, taking into account forward/backward, offset, collapse - void CopyScores(Scores& scores, const TranslationOption& topt, ReorderingType reoType) const; - void ClearScores(Scores& scores) const; - int ComparePrevScores(const Scores *other) const; + void CopyScores(ScoreComponentCollection* scores, const TranslationOption& topt, const InputType& input, ReorderingType reoType) const; + int ComparePrevScores(const TranslationOption *other) const; //constants for the different type of reorderings (corresponding to indexes in the table file) + public: static const ReorderingType M = 0; // monotonic static const ReorderingType NM = 1; // non-monotonic static const ReorderingType S = 1; // swap @@ -122,6 +133,7 @@ protected: static const ReorderingType DR = 3; // discontinuous, right static const ReorderingType R = 0; // right static const ReorderingType L = 1; // left + static const ReorderingType MAX = 3; //largest possible }; //! @todo what is this? 
@@ -140,7 +152,7 @@ public: } virtual int Compare(const FFState& o) const; - virtual LexicalReorderingState* Expand(const TranslationOption& topt, Scores& scores) const; + virtual LexicalReorderingState* Expand(const TranslationOption& topt, const InputType& input, ScoreComponentCollection* scores) const; }; //! State for the standard Moses implementation of lexical reordering models @@ -156,7 +168,7 @@ public: PhraseBasedReorderingState(const PhraseBasedReorderingState *prev, const TranslationOption &topt); virtual int Compare(const FFState& o) const; - virtual LexicalReorderingState* Expand(const TranslationOption& topt, Scores& scores) const; + virtual LexicalReorderingState* Expand(const TranslationOption& topt,const InputType& input, ScoreComponentCollection* scores) const; ReorderingType GetOrientationTypeMSD(WordsRange currRange) const; ReorderingType GetOrientationTypeMSLR(WordsRange currRange) const; @@ -177,7 +189,7 @@ public: const TranslationOption &topt, ReorderingStack reoStack); virtual int Compare(const FFState& o) const; - virtual LexicalReorderingState* Expand(const TranslationOption& hypo, Scores& scores) const; + virtual LexicalReorderingState* Expand(const TranslationOption& hypo, const InputType& input, ScoreComponentCollection* scores) const; private: ReorderingType GetOrientationTypeMSD(int reoDistance) const; @@ -200,7 +212,7 @@ public: HierarchicalReorderingForwardState(const HierarchicalReorderingForwardState *prev, const TranslationOption &topt); virtual int Compare(const FFState& o) const; - virtual LexicalReorderingState* Expand(const TranslationOption& hypo, Scores& scores) const; + virtual LexicalReorderingState* Expand(const TranslationOption& hypo, const InputType& input, ScoreComponentCollection* scores) const; private: ReorderingType GetOrientationTypeMSD(WordsRange currRange, WordsBitmap coverage) const; diff --git a/moses/ReorderingStack.cpp b/moses/FF/LexicalReordering/ReorderingStack.cpp similarity index 100% rename from 
moses/ReorderingStack.cpp rename to moses/FF/LexicalReordering/ReorderingStack.cpp diff --git a/moses/ReorderingStack.h b/moses/FF/LexicalReordering/ReorderingStack.h similarity index 94% rename from moses/ReorderingStack.h rename to moses/FF/LexicalReordering/ReorderingStack.h index 730b17ce3..5a5b80d16 100644 --- a/moses/ReorderingStack.h +++ b/moses/FF/LexicalReordering/ReorderingStack.h @@ -12,7 +12,7 @@ //#include "Phrase.h" //#include "TypeDef.h" //#include "Util.h" -#include "WordsRange.h" +#include "moses/WordsRange.h" namespace Moses { diff --git a/moses/FF/LexicalReordering/SparseReordering.cpp b/moses/FF/LexicalReordering/SparseReordering.cpp new file mode 100644 index 000000000..bc519eefc --- /dev/null +++ b/moses/FF/LexicalReordering/SparseReordering.cpp @@ -0,0 +1,252 @@ +#include + +#include "moses/FactorCollection.h" +#include "moses/InputPath.h" +#include "moses/Util.h" + +#include "util/exception.hh" + +#include "util/file_piece.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include "LexicalReordering.h" +#include "SparseReordering.h" + + +using namespace std; + +namespace Moses +{ + +const std::string& SparseReorderingFeatureKey::Name(const string& wordListId) { + static string kSep = "-"; + static string name; + ostringstream buf; + // type side position id word reotype + if (type == Phrase) { + buf << "phr"; + } else if (type == Stack) { + buf << "stk"; + } else if (type == Between) { + buf << "btn"; + } + buf << kSep; + if (side == Source) { + buf << "src"; + } else if (side == Target) { + buf << "tgt"; + } + buf << kSep; + if (position == First) { + buf << "first"; + } else if (position == Last) { + buf << "last"; + } + buf << kSep; + buf << wordListId; + buf << kSep; + if (isCluster) buf << "cluster_"; + buf << word->GetString(); + buf << kSep; + buf << reoType; + name = buf.str(); + return name; +} + +SparseReordering::SparseReordering(const map& config, const LexicalReordering* producer) + : 
m_producer(producer) +{ + static const string kSource= "source"; + static const string kTarget = "target"; + for (map::const_iterator i = config.begin(); i != config.end(); ++i) { + vector fields = Tokenize(i->first, "-"); + if (fields[0] == "words") { + UTIL_THROW_IF(!(fields.size() == 3), util::Exception, "Sparse reordering word list name should be sparse-words-(source|target)-"); + if (fields[1] == kSource) { + ReadWordList(i->second,fields[2], SparseReorderingFeatureKey::Source, &m_sourceWordLists); + } else if (fields[1] == kTarget) { + ReadWordList(i->second,fields[2],SparseReorderingFeatureKey::Target, &m_targetWordLists); + } else { + UTIL_THROW(util::Exception, "Sparse reordering requires source or target, not " << fields[1]); + } + } else if (fields[0] == "clusters") { + UTIL_THROW_IF(!(fields.size() == 3), util::Exception, "Sparse reordering cluster name should be sparse-clusters-(source|target)-"); + if (fields[1] == kSource) { + ReadClusterMap(i->second,fields[2], SparseReorderingFeatureKey::Source, &m_sourceClusterMaps); + } else if (fields[1] == kTarget) { + ReadClusterMap(i->second,fields[2],SparseReorderingFeatureKey::Target, &m_targetClusterMaps); + } else { + UTIL_THROW(util::Exception, "Sparse reordering requires source or target, not " << fields[1]); + } + + } else if (fields[0] == "phrase") { + m_usePhrase = true; + } else if (fields[0] == "stack") { + m_useStack = true; + } else if (fields[0] == "between") { + m_useBetween = true; + } else { + UTIL_THROW(util::Exception, "Unable to parse sparse reordering option: " << i->first); + } + } + +} + +void SparseReordering::PreCalculateFeatureNames(size_t index, const string& id, SparseReorderingFeatureKey::Side side, const Factor* factor, bool isCluster) { + for (size_t type = SparseReorderingFeatureKey::Stack; + type <= SparseReorderingFeatureKey::Between; ++type) { + for (size_t position = SparseReorderingFeatureKey::First; + position <= SparseReorderingFeatureKey::Last; ++position) { + for (int 
reoType = 0; reoType <= LexicalReorderingState::MAX; ++reoType) { + SparseReorderingFeatureKey key( + index, static_cast(type), factor, isCluster, + static_cast(position), side, reoType); + m_featureMap[key] = key.Name(id); + } + } + } +} + +void SparseReordering::ReadWordList(const string& filename, const string& id, SparseReorderingFeatureKey::Side side, vector* pWordLists) { + ifstream fh(filename.c_str()); + UTIL_THROW_IF(!fh, util::Exception, "Unable to open: " << filename); + string line; + pWordLists->push_back(WordList()); + pWordLists->back().first = id; + while (getline(fh,line)) { + //TODO: StringPiece + const Factor* factor = FactorCollection::Instance().AddFactor(line); + pWordLists->back().second.insert(factor); + PreCalculateFeatureNames(pWordLists->size()-1, id, side, factor, false); + + } +} + +void SparseReordering::ReadClusterMap(const string& filename, const string& id, SparseReorderingFeatureKey::Side side, vector* pClusterMaps) { + pClusterMaps->push_back(ClusterMap()); + pClusterMaps->back().first = id; + util::FilePiece file(filename.c_str()); + StringPiece line; + while (true) { + try { + line = file.ReadLine(); + } catch (const util::EndOfFileException &e) { + break; + } + util::TokenIter lineIter(line,util::SingleCharacter('\t')); + if (!lineIter) UTIL_THROW(util::Exception, "Malformed cluster line (missing word): '" << line << "'"); + const Factor* wordFactor = FactorCollection::Instance().AddFactor(*lineIter); + ++lineIter; + if (!lineIter) UTIL_THROW(util::Exception, "Malformed cluster line (missing cluster id): '" << line << "'"); + const Factor* idFactor = FactorCollection::Instance().AddFactor(*lineIter); + pClusterMaps->back().second[wordFactor] = idFactor; + PreCalculateFeatureNames(pClusterMaps->size()-1, id, side, idFactor, true); + } +} + +void SparseReordering::AddFeatures( + SparseReorderingFeatureKey::Type type, SparseReorderingFeatureKey::Side side, + const Word& word, SparseReorderingFeatureKey::Position position, + 
LexicalReorderingState::ReorderingType reoType, + ScoreComponentCollection* scores) const { + + const Factor* wordFactor = word.GetFactor(0); + + const vector* wordLists; + const vector* clusterMaps; + if (side == SparseReorderingFeatureKey::Source) { + wordLists = &m_sourceWordLists; + clusterMaps = &m_sourceClusterMaps; + } else { + wordLists = &m_targetWordLists; + clusterMaps = &m_targetClusterMaps; + } + + for (size_t id = 0; id < wordLists->size(); ++id) { + if ((*wordLists)[id].second.find(wordFactor) == (*wordLists)[id].second.end()) continue; + SparseReorderingFeatureKey key(id, type, wordFactor, false, position, side, reoType); + FeatureMap::const_iterator fmi = m_featureMap.find(key); + assert(fmi != m_featureMap.end()); + scores->PlusEquals(m_producer, fmi->second, 1.0); + } + + for (size_t id = 0; id < clusterMaps->size(); ++id) { + const ClusterMap& clusterMap = (*clusterMaps)[id]; + boost::unordered_map::const_iterator clusterIter + = clusterMap.second.find(wordFactor); + if (clusterIter != clusterMap.second.end()) { + SparseReorderingFeatureKey key(id, type, clusterIter->second, true, position, side, reoType); + FeatureMap::const_iterator fmi = m_featureMap.find(key); + assert(fmi != m_featureMap.end()); + scores->PlusEquals(m_producer, fmi->second, 1.0); + } + } + +} + +void SparseReordering::CopyScores( + const TranslationOption& currentOpt, + const TranslationOption* previousOpt, + const InputType& input, + LexicalReorderingState::ReorderingType reoType, + LexicalReorderingConfiguration::Direction direction, + ScoreComponentCollection* scores) const +{ + if (m_useBetween && direction == LexicalReorderingConfiguration::Backward && + (reoType == LexicalReorderingState::D || reoType == LexicalReorderingState::DL || + reoType == LexicalReorderingState::DR)) { + size_t gapStart, gapEnd; + const Sentence& sentence = dynamic_cast(input); + const WordsRange& currentRange = currentOpt.GetSourceWordsRange(); + if (previousOpt) { + const WordsRange& 
previousRange = previousOpt->GetSourceWordsRange(); + if (previousRange < currentRange) { + gapStart = previousRange.GetEndPos() + 1; + gapEnd = currentRange.GetStartPos(); + } else { + gapStart = currentRange.GetEndPos() + 1; + gapEnd = previousRange.GetStartPos(); + } + } else { + //start of sentence + gapStart = 0; + gapEnd = currentRange.GetStartPos(); + } + assert(gapStart < gapEnd); + for (size_t i = gapStart; i < gapEnd; ++i) { + AddFeatures(SparseReorderingFeatureKey::Between, + SparseReorderingFeatureKey::Source, sentence.GetWord(i), + SparseReorderingFeatureKey::First, reoType, scores); + } + } + //std::cerr << "SR " << topt << " " << reoType << " " << direction << std::endl; + //phrase (backward) + //stack (forward) + SparseReorderingFeatureKey::Type type; + if (direction == LexicalReorderingConfiguration::Forward) { + if (!m_useStack) return; + type = SparseReorderingFeatureKey::Stack; + } else if (direction == LexicalReorderingConfiguration::Backward) { + if (!m_usePhrase) return; + type = SparseReorderingFeatureKey::Phrase; + } else { + //Shouldn't be called for bidirectional + //keep compiler happy + type = SparseReorderingFeatureKey::Phrase; + assert(!"Shouldn't call CopyScores() with bidirectional direction"); + } + const Phrase& sourcePhrase = currentOpt.GetInputPath().GetPhrase(); + AddFeatures(type, SparseReorderingFeatureKey::Source, sourcePhrase.GetWord(0), + SparseReorderingFeatureKey::First, reoType, scores); + AddFeatures(type, SparseReorderingFeatureKey::Source, sourcePhrase.GetWord(sourcePhrase.GetSize()-1), SparseReorderingFeatureKey::Last, reoType, scores); + const Phrase& targetPhrase = currentOpt.GetTargetPhrase(); + AddFeatures(type, SparseReorderingFeatureKey::Target, targetPhrase.GetWord(0), + SparseReorderingFeatureKey::First, reoType, scores); + AddFeatures(type, SparseReorderingFeatureKey::Target, targetPhrase.GetWord(targetPhrase.GetSize()-1), SparseReorderingFeatureKey::Last, reoType, scores); + + +} + +} //namespace + diff 
--git a/moses/FF/LexicalReordering/SparseReordering.h b/moses/FF/LexicalReordering/SparseReordering.h new file mode 100644 index 000000000..e496daf94 --- /dev/null +++ b/moses/FF/LexicalReordering/SparseReordering.h @@ -0,0 +1,132 @@ +#ifndef moses_FF_LexicalReordering_SparseReordering_h +#define moses_FF_LexicalReordering_SparseReordering_h + +/** + * Sparse reordering features for phrase-based MT, following Cherry (NAACL, 2013) +**/ + + +#include +#include +#include +#include + +#include + +#include "util/murmur_hash.hh" +#include "util/pool.hh" +#include "util/string_piece.hh" + +#include "moses/ScoreComponentCollection.h" +#include "LexicalReorderingState.h" + +/** + Configuration of sparse reordering: + + The sparse reordering feature is configured using sparse-* configs in the lexical reordering line. + sparse-words-(source|target)-= -- Features which fire for the words in the list + sparse-clusters-(source|target)-= -- Features which fire for clusters in the list. Format + of cluster file TBD + sparse-phrase -- Add features which depend on the current phrase (backward) + sparse-stack -- Add features which depend on the previous phrase, or + top of stack. (forward) + sparse-between -- Add features which depend on words between previous phrase + (or top of stack) and current phrase. +**/ + +namespace Moses +{ + +/** + * Used to store pre-calculated feature names. 
+**/ +struct SparseReorderingFeatureKey { + size_t id; + enum Type {Stack, Phrase, Between} type; + const Factor* word; + bool isCluster; + enum Position {First, Last} position; + enum Side {Source, Target} side; + LexicalReorderingState::ReorderingType reoType; + + SparseReorderingFeatureKey(size_t id_, Type type_, const Factor* word_, bool isCluster_, + Position position_, Side side_, LexicalReorderingState::ReorderingType reoType_) + : id(id_), type(type_), word(word_), isCluster(isCluster_), + position(position_), side(side_), reoType(reoType_) + {} + + const std::string& Name(const std::string& wordListId) ; +}; + +struct HashSparseReorderingFeatureKey : public std::unary_function { + std::size_t operator()(const SparseReorderingFeatureKey& key) const { + //TODO: can we just hash the memory? + //not sure, there could be random padding + std::size_t seed = 0; + seed = util::MurmurHashNative(&key.id, sizeof(key.id), seed); + seed = util::MurmurHashNative(&key.type, sizeof(key.type), seed); + seed = util::MurmurHashNative(&key.word, sizeof(key.word), seed); + seed = util::MurmurHashNative(&key.isCluster, sizeof(key.isCluster), seed); + seed = util::MurmurHashNative(&key.position, sizeof(key.position), seed); + seed = util::MurmurHashNative(&key.side, sizeof(key.side), seed); + seed = util::MurmurHashNative(&key.reoType, sizeof(key.reoType), seed); + return seed; + } +}; + +struct EqualsSparseReorderingFeatureKey : + public std::binary_function { + bool operator()(const SparseReorderingFeatureKey& left, const SparseReorderingFeatureKey& right) const { + //TODO: Can we just compare the memory? 
+ return left.id == right.id && left.type == right.type && left.word == right.word && + left.position == right.position && left.side == right.side && + left.reoType == right.reoType; + } +}; + +class SparseReordering +{ +public: + SparseReordering(const std::map& config, const LexicalReordering* producer); + + //If direction is backward the options will be different, for forward they will be the same + void CopyScores(const TranslationOption& currentOpt, + const TranslationOption* previousOpt, + const InputType& input, + LexicalReorderingState::ReorderingType reoType, + LexicalReorderingConfiguration::Direction direction, + ScoreComponentCollection* scores) const ; + +private: + const LexicalReordering* m_producer; + typedef std::pair > WordList; //id and list + std::vector m_sourceWordLists; + std::vector m_targetWordLists; + typedef std::pair > ClusterMap; //id and map + std::vector m_sourceClusterMaps; + std::vector m_targetClusterMaps; + bool m_usePhrase; + bool m_useBetween; + bool m_useStack; + typedef boost::unordered_map FeatureMap; + FeatureMap m_featureMap; + + void ReadWordList(const std::string& filename, const std::string& id, + SparseReorderingFeatureKey::Side side, std::vector* pWordLists); + void ReadClusterMap(const std::string& filename, const std::string& id, SparseReorderingFeatureKey::Side side, std::vector* pClusterMaps); + void PreCalculateFeatureNames(size_t index, const std::string& id, SparseReorderingFeatureKey::Side side, const Factor* factor, bool isCluster); + + void AddFeatures( + SparseReorderingFeatureKey::Type type, SparseReorderingFeatureKey::Side side, + const Word& word, SparseReorderingFeatureKey::Position position, + LexicalReorderingState::ReorderingType reoType, + ScoreComponentCollection* scores) const; + +}; + + + +} //namespace + + +#endif diff --git a/moses/FF/MaxSpanFreeNonTermSource.cpp b/moses/FF/MaxSpanFreeNonTermSource.cpp index 3951fdd27..9de582635 100644 --- a/moses/FF/MaxSpanFreeNonTermSource.cpp +++ 
b/moses/FF/MaxSpanFreeNonTermSource.cpp @@ -27,7 +27,7 @@ MaxSpanFreeNonTermSource::MaxSpanFreeNonTermSource(const std::string &line) m_glueTargetLHS.SetFactor(0, factor); } -void MaxSpanFreeNonTermSource::Evaluate(const Phrase &source +void MaxSpanFreeNonTermSource::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const @@ -35,7 +35,7 @@ void MaxSpanFreeNonTermSource::Evaluate(const Phrase &source targetPhrase.SetRuleSource(source); } -void MaxSpanFreeNonTermSource::Evaluate(const InputType &input +void MaxSpanFreeNonTermSource::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec diff --git a/moses/FF/MaxSpanFreeNonTermSource.h b/moses/FF/MaxSpanFreeNonTermSource.h index a9eec7b5e..973b374d8 100644 --- a/moses/FF/MaxSpanFreeNonTermSource.h +++ b/moses/FF/MaxSpanFreeNonTermSource.h @@ -15,23 +15,23 @@ public: virtual bool IsUseable(const FactorMask &mask) const { return true; } - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - virtual void Evaluate(const InputType &input + virtual void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const; - virtual void Evaluate(const Hypothesis& hypo, + virtual void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - virtual void EvaluateChart(const ChartHypothesis &hypo, + virtual void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git 
a/moses/FF/NieceTerminal.cpp b/moses/FF/NieceTerminal.cpp index eaef9f79e..b3a5f8f92 100644 --- a/moses/FF/NieceTerminal.cpp +++ b/moses/FF/NieceTerminal.cpp @@ -17,7 +17,15 @@ NieceTerminal::NieceTerminal(const std::string &line) ReadParameters(); } -void NieceTerminal::Evaluate(const Phrase &source +std::vector NieceTerminal::DefaultWeights() const +{ + UTIL_THROW_IF2(m_numScoreComponents != 1, + "NieceTerminal must only have 1 score"); + vector ret(1, 1); + return ret; +} + +void NieceTerminal::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const @@ -25,7 +33,7 @@ void NieceTerminal::Evaluate(const Phrase &source targetPhrase.SetRuleSource(source); } -void NieceTerminal::Evaluate(const InputType &input +void NieceTerminal::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -63,11 +71,11 @@ void NieceTerminal::Evaluate(const InputType &input } -void NieceTerminal::Evaluate(const Hypothesis& hypo, +void NieceTerminal::EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} -void NieceTerminal::EvaluateChart(const ChartHypothesis &hypo, +void NieceTerminal::EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/NieceTerminal.h b/moses/FF/NieceTerminal.h index 9b56e489f..7daf2963e 100644 --- a/moses/FF/NieceTerminal.h +++ b/moses/FF/NieceTerminal.h @@ -19,22 +19,23 @@ public: return true; } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase 
&targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const; - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const; void SetParameter(const std::string& key, const std::string& value); + std::vector DefaultWeights() const; protected: bool m_hardConstraint; diff --git a/moses/FF/OSM-Feature/OpSequenceModel.cpp b/moses/FF/OSM-Feature/OpSequenceModel.cpp index dfa380a77..793942151 100644 --- a/moses/FF/OSM-Feature/OpSequenceModel.cpp +++ b/moses/FF/OSM-Feature/OpSequenceModel.cpp @@ -42,7 +42,7 @@ void OpSequenceModel::Load() -void OpSequenceModel:: Evaluate(const Phrase &source +void OpSequenceModel:: EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const @@ -87,7 +87,7 @@ void OpSequenceModel:: Evaluate(const Phrase &source } -FFState* OpSequenceModel::Evaluate( +FFState* OpSequenceModel::EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const @@ -194,7 +194,7 @@ FFState* OpSequenceModel::Evaluate( // return NULL; } -FFState* OpSequenceModel::EvaluateChart( +FFState* OpSequenceModel::EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const diff --git a/moses/FF/OSM-Feature/OpSequenceModel.h b/moses/FF/OSM-Feature/OpSequenceModel.h index 64cab3044..c4d26f98e 100644 --- a/moses/FF/OSM-Feature/OpSequenceModel.h +++ b/moses/FF/OSM-Feature/OpSequenceModel.h @@ -29,24 +29,24 @@ public: void readLanguageModel(const char *); void Load(); - FFState* Evaluate( 
+ FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - virtual FFState* EvaluateChart( + virtual FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; diff --git a/moses/FF/PhraseBoundaryFeature.cpp b/moses/FF/PhraseBoundaryFeature.cpp index d82181b76..3fdcf27f9 100644 --- a/moses/FF/PhraseBoundaryFeature.cpp +++ b/moses/FF/PhraseBoundaryFeature.cpp @@ -66,7 +66,7 @@ void PhraseBoundaryFeature::AddFeatures( } -FFState* PhraseBoundaryFeature::Evaluate +FFState* PhraseBoundaryFeature::EvaluateWhenApplied (const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* scores) const { diff --git a/moses/FF/PhraseBoundaryFeature.h b/moses/FF/PhraseBoundaryFeature.h index fbafc6da9..e4c3ca3ba 100644 --- a/moses/FF/PhraseBoundaryFeature.h +++ b/moses/FF/PhraseBoundaryFeature.h @@ -44,23 +44,23 @@ public: virtual const FFState* EmptyHypothesisState(const InputType &) const; - virtual FFState* Evaluate(const Hypothesis& cur_hypo, const FFState* prev_state, + virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - virtual FFState* EvaluateChart( const ChartHypothesis& /* cur_hypo */, + virtual FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo 
*/, int /* featureID */, ScoreComponentCollection* ) const { throw std::logic_error("PhraseBoundaryState not supported in chart decoder, yet"); } - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/PhraseLengthFeature.cpp b/moses/FF/PhraseLengthFeature.cpp index 43e0d1b2d..7850c374a 100644 --- a/moses/FF/PhraseLengthFeature.cpp +++ b/moses/FF/PhraseLengthFeature.cpp @@ -15,7 +15,7 @@ PhraseLengthFeature::PhraseLengthFeature(const std::string &line) ReadParameters(); } -void PhraseLengthFeature::Evaluate(const Phrase &source +void PhraseLengthFeature::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/PhraseLengthFeature.h b/moses/FF/PhraseLengthFeature.h index ba835f654..4976e2210 100644 --- a/moses/FF/PhraseLengthFeature.h +++ b/moses/FF/PhraseLengthFeature.h @@ -24,16 +24,16 @@ public: return true; } - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis& hypo, + void EvaluateWhenApplied(const ChartHypothesis& hypo, ScoreComponentCollection*) const { throw std::logic_error("PhraseLengthFeature not valid in chart decoder"); } - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -41,7 
+41,7 @@ public: , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; diff --git a/moses/FF/PhrasePairFeature.cpp b/moses/FF/PhrasePairFeature.cpp index 9277e19f2..f359b68f7 100644 --- a/moses/FF/PhrasePairFeature.cpp +++ b/moses/FF/PhrasePairFeature.cpp @@ -106,7 +106,7 @@ void PhrasePairFeature::Load() } } -void PhrasePairFeature::Evaluate( +void PhrasePairFeature::EvaluateWhenApplied( const Hypothesis& hypo, ScoreComponentCollection* accumulator) const { diff --git a/moses/FF/PhrasePairFeature.h b/moses/FF/PhrasePairFeature.h index 7790e9035..8bfac628d 100644 --- a/moses/FF/PhrasePairFeature.h +++ b/moses/FF/PhrasePairFeature.h @@ -37,22 +37,22 @@ public: bool IsUseable(const FactorMask &mask) const; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const; - void EvaluateChart(const ChartHypothesis& hypo, + void EvaluateWhenApplied(const ChartHypothesis& hypo, ScoreComponentCollection*) const { throw std::logic_error("PhrasePairFeature not valid in chart decoder"); } - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/PhrasePenalty.cpp b/moses/FF/PhrasePenalty.cpp index b3e493707..ddd21e491 100644 --- a/moses/FF/PhrasePenalty.cpp +++ b/moses/FF/PhrasePenalty.cpp @@ 
-10,7 +10,7 @@ PhrasePenalty::PhrasePenalty(const std::string &line) ReadParameters(); } -void PhrasePenalty::Evaluate(const Phrase &source +void PhrasePenalty::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/PhrasePenalty.h b/moses/FF/PhrasePenalty.h index a4014abf1..f822e583b 100644 --- a/moses/FF/PhrasePenalty.h +++ b/moses/FF/PhrasePenalty.h @@ -14,19 +14,19 @@ public: return true; } - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec diff --git a/moses/FF/ReferenceComparison.h b/moses/FF/ReferenceComparison.h index 8b0341fd6..62cf15ced 100644 --- a/moses/FF/ReferenceComparison.h +++ b/moses/FF/ReferenceComparison.h @@ -15,13 +15,13 @@ public: virtual bool IsUseable(const FactorMask &mask) const { return true; } - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {} - virtual void Evaluate(const InputType &input + virtual void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -29,11 +29,11 @@ 
public: , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - virtual void Evaluate(const Hypothesis& hypo, + virtual void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - virtual void EvaluateChart(const ChartHypothesis &hypo, + virtual void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/RuleScope.cpp b/moses/FF/RuleScope.cpp index e949c3337..ed329c4ca 100644 --- a/moses/FF/RuleScope.cpp +++ b/moses/FF/RuleScope.cpp @@ -16,7 +16,7 @@ bool IsAmbiguous(const Word &word, bool sourceSyntax) return word.IsNonTerminal() && (!sourceSyntax || word == inputDefaultNonTerminal); } -void RuleScope::Evaluate(const Phrase &source +void RuleScope::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/RuleScope.h b/moses/FF/RuleScope.h index 4ac10c804..a2c9e06f3 100644 --- a/moses/FF/RuleScope.h +++ b/moses/FF/RuleScope.h @@ -14,12 +14,12 @@ public: virtual bool IsUseable(const FactorMask &mask) const { return true; } - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - virtual void Evaluate(const InputType &input + virtual void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -27,11 +27,11 @@ public: , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - virtual void Evaluate(const Hypothesis& hypo, + virtual void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - virtual void EvaluateChart(const ChartHypothesis &hypo, + virtual void EvaluateWhenApplied(const ChartHypothesis &hypo, 
ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/SetSourcePhrase.cpp b/moses/FF/SetSourcePhrase.cpp index 0a2eaa4cb..f89683f28 100644 --- a/moses/FF/SetSourcePhrase.cpp +++ b/moses/FF/SetSourcePhrase.cpp @@ -10,7 +10,7 @@ SetSourcePhrase::SetSourcePhrase(const std::string &line) ReadParameters(); } -void SetSourcePhrase::Evaluate(const Phrase &source +void SetSourcePhrase::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/SetSourcePhrase.h b/moses/FF/SetSourcePhrase.h index 0d7ad2ade..81f293dde 100644 --- a/moses/FF/SetSourcePhrase.h +++ b/moses/FF/SetSourcePhrase.h @@ -14,12 +14,12 @@ public: virtual bool IsUseable(const FactorMask &mask) const { return true; } - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - virtual void Evaluate(const InputType &input + virtual void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -27,11 +27,11 @@ public: , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - virtual void Evaluate(const Hypothesis& hypo, + virtual void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - virtual void EvaluateChart(const ChartHypothesis &hypo, + virtual void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/SkeletonStatefulFF.cpp b/moses/FF/SkeletonStatefulFF.cpp index 44771e646..fe81aeeae 100644 --- a/moses/FF/SkeletonStatefulFF.cpp +++ b/moses/FF/SkeletonStatefulFF.cpp @@ -16,13 +16,20 @@ int SkeletonState::Compare(const FFState& other) const return (m_targetLen < 
otherState.m_targetLen) ? -1 : +1; } -void SkeletonStatefulFF::Evaluate(const Phrase &source +//////////////////////////////////////////////////////////////// +SkeletonStatefulFF::SkeletonStatefulFF(const std::string &line) + :StatefulFeatureFunction(3, line) +{ + ReadParameters(); +} + +void SkeletonStatefulFF::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {} -void SkeletonStatefulFF::Evaluate(const InputType &input +void SkeletonStatefulFF::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec @@ -30,7 +37,7 @@ void SkeletonStatefulFF::Evaluate(const InputType &input , ScoreComponentCollection *estimatedFutureScore) const {} -FFState* SkeletonStatefulFF::Evaluate( +FFState* SkeletonStatefulFF::EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const @@ -49,7 +56,7 @@ FFState* SkeletonStatefulFF::Evaluate( return new SkeletonState(0); } -FFState* SkeletonStatefulFF::EvaluateChart( +FFState* SkeletonStatefulFF::EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const @@ -57,6 +64,14 @@ FFState* SkeletonStatefulFF::EvaluateChart( return new SkeletonState(0); } +void SkeletonStatefulFF::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "arg") { + // set value here + } else { + StatefulFeatureFunction::SetParameter(key, value); + } +} } diff --git a/moses/FF/SkeletonStatefulFF.h b/moses/FF/SkeletonStatefulFF.h index 1f2baa92b..6fa26803e 100644 --- a/moses/FF/SkeletonStatefulFF.h +++ b/moses/FF/SkeletonStatefulFF.h @@ -21,9 +21,7 @@ public: class SkeletonStatefulFF : public StatefulFeatureFunction { public: - SkeletonStatefulFF(const std::string 
&line) - :StatefulFeatureFunction(3, line) - {} + SkeletonStatefulFF(const std::string &line); bool IsUseable(const FactorMask &mask) const { return true; @@ -32,25 +30,27 @@ public: return new SkeletonState(0); } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const; - FFState* Evaluate( + FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - FFState* EvaluateChart( + FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const; + void SetParameter(const std::string& key, const std::string& value); + }; diff --git a/moses/FF/SkeletonStatelessFF.cpp b/moses/FF/SkeletonStatelessFF.cpp index 0ef8570ee..80c7d130e 100644 --- a/moses/FF/SkeletonStatelessFF.cpp +++ b/moses/FF/SkeletonStatelessFF.cpp @@ -1,12 +1,19 @@ +#include #include "SkeletonStatelessFF.h" #include "moses/ScoreComponentCollection.h" -#include +#include "moses/TargetPhrase.h" using namespace std; namespace Moses { -void SkeletonStatelessFF::Evaluate(const Phrase &source +SkeletonStatelessFF::SkeletonStatelessFF(const std::string &line) + :StatelessFeatureFunction(2, line) +{ + ReadParameters(); +} + +void SkeletonStatelessFF::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const @@ -22,21 +29,37 @@ void 
SkeletonStatelessFF::Evaluate(const Phrase &source } -void SkeletonStatelessFF::Evaluate(const InputType &input +void SkeletonStatelessFF::EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore) const -{} - -void SkeletonStatelessFF::Evaluate(const Hypothesis& hypo, - ScoreComponentCollection* accumulator) const -{} - -void SkeletonStatelessFF::EvaluateChart(const ChartHypothesis &hypo, - ScoreComponentCollection* accumulator) const -{} +{ + if (targetPhrase.GetNumNonTerminals()) { + vector newScores(m_numScoreComponents); + newScores[0] = - std::numeric_limits::infinity(); + scoreBreakdown.PlusEquals(this, newScores); + } + +} + +void SkeletonStatelessFF::EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const +{} + +void SkeletonStatelessFF::EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const +{} + +void SkeletonStatelessFF::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "arg") { + // set value here + } else { + StatelessFeatureFunction::SetParameter(key, value); + } +} } diff --git a/moses/FF/SkeletonStatelessFF.h b/moses/FF/SkeletonStatelessFF.h index 6aac207f4..520ec1405 100644 --- a/moses/FF/SkeletonStatelessFF.h +++ b/moses/FF/SkeletonStatelessFF.h @@ -9,29 +9,29 @@ namespace Moses class SkeletonStatelessFF : public StatelessFeatureFunction { public: - SkeletonStatelessFF(const std::string &line) - :StatelessFeatureFunction(2, line) - {} + SkeletonStatelessFF(const std::string &line); bool IsUseable(const FactorMask &mask) const { return true; } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void 
Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const; - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const; + void SetParameter(const std::string& key, const std::string& value); + }; } diff --git a/moses/FF/SoftMatchingFeature.cpp b/moses/FF/SoftMatchingFeature.cpp index 017e551c4..0475547da 100644 --- a/moses/FF/SoftMatchingFeature.cpp +++ b/moses/FF/SoftMatchingFeature.cpp @@ -61,7 +61,7 @@ bool SoftMatchingFeature::Load(const std::string& filePath) return true; } -void SoftMatchingFeature::EvaluateChart(const ChartHypothesis& hypo, +void SoftMatchingFeature::EvaluateWhenApplied(const ChartHypothesis& hypo, ScoreComponentCollection* accumulator) const { diff --git a/moses/FF/SoftMatchingFeature.h b/moses/FF/SoftMatchingFeature.h index b823c2426..ff923ea08 100644 --- a/moses/FF/SoftMatchingFeature.h +++ b/moses/FF/SoftMatchingFeature.h @@ -19,20 +19,20 @@ public: return true; } - virtual void EvaluateChart(const ChartHypothesis& hypo, + virtual void EvaluateWhenApplied(const ChartHypothesis& hypo, ScoreComponentCollection* accumulator) const; - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {}; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , 
ScoreComponentCollection *estimatedFutureScore = NULL) const {}; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {}; bool Load(const std::string &filePath); diff --git a/moses/FF/SourceGHKMTreeInputMatchFeature.cpp b/moses/FF/SourceGHKMTreeInputMatchFeature.cpp new file mode 100644 index 000000000..38238b10c --- /dev/null +++ b/moses/FF/SourceGHKMTreeInputMatchFeature.cpp @@ -0,0 +1,67 @@ +#include +#include +#include +#include "SourceGHKMTreeInputMatchFeature.h" +#include "moses/StaticData.h" +#include "moses/InputFileStream.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/Hypothesis.h" +#include "moses/ChartHypothesis.h" +#include "moses/Factor.h" +#include "moses/FactorCollection.h" +#include "moses/InputPath.h" +#include "moses/TreeInput.h" + + +using namespace std; + +namespace Moses +{ + +SourceGHKMTreeInputMatchFeature::SourceGHKMTreeInputMatchFeature(const std::string &line) + : StatelessFeatureFunction(2, line) +{ + std::cerr << GetScoreProducerDescription() << "Initializing feature..."; + ReadParameters(); + std::cerr << " Done." 
<< std::endl; +} + +void SourceGHKMTreeInputMatchFeature::SetParameter(const std::string& key, const std::string& value) +{ + UTIL_THROW(util::Exception, GetScoreProducerDescription() << ": Unknown parameter " << key << "=" << value); +} + +// assumes that source-side syntax labels are stored in the target non-terminal field of the rules +void SourceGHKMTreeInputMatchFeature::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore) const +{ + const WordsRange& wordsRange = inputPath.GetWordsRange(); + size_t startPos = wordsRange.GetStartPos(); + size_t endPos = wordsRange.GetEndPos(); + const TreeInput& treeInput = static_cast(input); + const NonTerminalSet& treeInputLabels = treeInput.GetLabelSet(startPos,endPos); + const Word& lhsLabel = targetPhrase.GetTargetLHS(); + + const StaticData& staticData = StaticData::Instance(); + const Word& outputDefaultNonTerminal = staticData.GetOutputDefaultNonTerminal(); + + std::vector newScores(m_numScoreComponents,0.0); // m_numScoreComponents == 2 // first fires for matches, second for mismatches + + if ( (treeInputLabels.find(lhsLabel) != treeInputLabels.end()) && (lhsLabel != outputDefaultNonTerminal) ) { + // match + newScores[0] = 1.0; + } else { + // mismatch + newScores[1] = 1.0; + } + + scoreBreakdown.PlusEquals(this, newScores); +} + + +} + diff --git a/moses/FF/SourceGHKMTreeInputMatchFeature.h b/moses/FF/SourceGHKMTreeInputMatchFeature.h new file mode 100644 index 000000000..743871b1c --- /dev/null +++ b/moses/FF/SourceGHKMTreeInputMatchFeature.h @@ -0,0 +1,42 @@ +#pragma once + +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + +// assumes that source-side syntax labels are stored in the target non-terminal field of the rules +class SourceGHKMTreeInputMatchFeature : public StatelessFeatureFunction +{ +public: + 
SourceGHKMTreeInputMatchFeature(const std::string &line); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + void SetParameter(const std::string& key, const std::string& value); + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const {}; + + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore = NULL) const; + + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const {}; + + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const {}; + +}; + + +} + diff --git a/moses/FF/SourceWordDeletionFeature.cpp b/moses/FF/SourceWordDeletionFeature.cpp index 101e40579..e5167b93b 100644 --- a/moses/FF/SourceWordDeletionFeature.cpp +++ b/moses/FF/SourceWordDeletionFeature.cpp @@ -63,7 +63,7 @@ bool SourceWordDeletionFeature::IsUseable(const FactorMask &mask) const return ret; } -void SourceWordDeletionFeature::Evaluate(const Phrase &source +void SourceWordDeletionFeature::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/SourceWordDeletionFeature.h b/moses/FF/SourceWordDeletionFeature.h index 9b04476af..8211ef0ca 100644 --- a/moses/FF/SourceWordDeletionFeature.h +++ b/moses/FF/SourceWordDeletionFeature.h @@ -28,21 +28,21 @@ public: bool IsUseable(const FactorMask &mask) const; - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void 
Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/SpanLength.cpp b/moses/FF/SpanLength.cpp new file mode 100644 index 000000000..7a7c87be8 --- /dev/null +++ b/moses/FF/SpanLength.cpp @@ -0,0 +1,93 @@ +#include +#include "SpanLength.h" +#include "moses/StaticData.h" +#include "moses/Word.h" +#include "moses/ChartCellLabel.h" +#include "moses/WordsRange.h" +#include "moses/StackVec.h" +#include "moses/TargetPhrase.h" +#include "moses/PP/PhraseProperty.h" +#include "moses/PP/SpanLengthPhraseProperty.h" + +using namespace std; + +namespace Moses +{ +SpanLength::SpanLength(const std::string &line) +:StatelessFeatureFunction(1, line) +,m_smoothingMethod(None) +,m_const(0) +{ + ReadParameters(); +} + +void SpanLength::EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const +{ + targetPhrase.SetRuleSource(source); +} + +void SpanLength::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore) const +{ + assert(stackVec); + + const PhraseProperty *property = targetPhrase.GetProperty("SpanLength"); + if (property == NULL) { + return; + } + + const SpanLengthPhraseProperty *slProp = static_cast(property); + + const Phrase 
*ruleSource = targetPhrase.GetRuleSource(); + assert(ruleSource); + + float score = 0; + for (size_t i = 0; i < stackVec->size(); ++i) { + const ChartCellLabel &cell = *stackVec->at(i); + const WordsRange &ntRange = cell.GetCoverage(); + size_t sourceWidth = ntRange.GetNumWordsCovered(); + float prob = slProp->GetProb(i, sourceWidth, m_const); + score += TransformScore(prob); + } + + if (score < -100.0f) { + float weight = StaticData::Instance().GetWeight(this); + if (weight < 0) { + score = -100; + } + } + + scoreBreakdown.PlusEquals(this, score); + +} + +void SpanLength::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "smoothing") { + if (value == "plus-constant") { + m_smoothingMethod = PlusConst; + } + else if (value == "none") { + m_smoothingMethod = None; + } + else { + UTIL_THROW(util::Exception, "Unknown smoothing type " << value); + } + } + else if (key == "constant") { + m_const = Scan(value); + } + else { + StatelessFeatureFunction::SetParameter(key, value); + } +} + +} + diff --git a/moses/FF/SpanLength.h b/moses/FF/SpanLength.h new file mode 100644 index 000000000..dc5564fcd --- /dev/null +++ b/moses/FF/SpanLength.h @@ -0,0 +1,52 @@ +#pragma once +#include +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + +// Rule Scope - not quite completely implemented yet +class SpanLength : public StatelessFeatureFunction +{ +public: + SpanLength(const std::string &line); + + virtual bool IsUseable(const FactorMask &mask) const + { return true; } + + virtual void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const; + + virtual void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore = NULL) const; + + + 
virtual void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const + {} + + virtual void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const + {} + + void SetParameter(const std::string& key, const std::string& value); + +protected: + enum SmoothingMethod + { + None, + PlusConst, + }; + SmoothingMethod m_smoothingMethod; + + float m_const; +}; + +} + diff --git a/moses/FF/SparseHieroReorderingFeature.cpp b/moses/FF/SparseHieroReorderingFeature.cpp new file mode 100644 index 000000000..0c6ac4767 --- /dev/null +++ b/moses/FF/SparseHieroReorderingFeature.cpp @@ -0,0 +1,222 @@ +#include + +#include "moses/ChartHypothesis.h" +#include "moses/ChartManager.h" +#include "moses/FactorCollection.h" +#include "moses/Sentence.h" + +#include "util/exception.hh" + +#include "SparseHieroReorderingFeature.h" + +using namespace std; + +namespace Moses +{ + +SparseHieroReorderingFeature::SparseHieroReorderingFeature(const std::string &line) + :StatelessFeatureFunction(0, line), + m_type(SourceCombined), + m_sourceFactor(0), + m_targetFactor(0), + m_sourceVocabFile(""), + m_targetVocabFile("") +{ + + /* + Configuration of features. + factor - Which factor should it apply to + type - what type of sparse reordering feature. e.g. block (modelled on Matthias + Huck's EAMT 2012 features) + word - which words to include, e.g. src_bdry, src_all, tgt_bdry , ... + vocab - vocab file to limit it to + orientation - e.g. lr, etc. 
+ */ + cerr << "Constructing a Sparse Reordering feature" << endl; + ReadParameters(); + m_otherFactor = FactorCollection::Instance().AddFactor("##OTHER##"); + LoadVocabulary(m_sourceVocabFile, m_sourceVocab); + LoadVocabulary(m_targetVocabFile, m_targetVocab); +} + +void SparseHieroReorderingFeature::SetParameter(const std::string& key, const std::string& value) { + if (key == "input-factor") { + m_sourceFactor = Scan(value); + } else if (key == "output-factor") { + m_targetFactor = Scan(value); + } else if (key == "input-vocab-file") { + m_sourceVocabFile = value; + } else if (key == "output-vocab-file") { + m_targetVocabFile = value; + } else if (key == "type") { + if (value == "SourceCombined") { + m_type = SourceCombined; + } else if (value == "SourceLeft") { + m_type = SourceLeft; + } else if (value == "SourceRight") { + m_type = SourceRight; + } else { + UTIL_THROW(util::Exception, "Unknown sparse reordering type " << value); + } + } else { + FeatureFunction::SetParameter(key, value); + } +} + +void SparseHieroReorderingFeature::LoadVocabulary(const std::string& filename, Vocab& vocab) +{ + if (filename.empty()) return; + ifstream in(filename.c_str()); + UTIL_THROW_IF(!in, util::Exception, "Unable to open vocab file: " << filename); + string line; + while(getline(in,line)) { + vocab.insert(FactorCollection::Instance().AddFactor(line)); + } + in.close(); +} + +const Factor* SparseHieroReorderingFeature::GetFactor(const Word& word, const Vocab& vocab, FactorType factorType) const { + const Factor* factor = word.GetFactor(factorType); + if (vocab.size() && vocab.find(factor) == vocab.end()) return m_otherFactor; + return factor; +} + +void SparseHieroReorderingFeature::EvaluateWhenApplied( + const ChartHypothesis& cur_hypo , + ScoreComponentCollection* accumulator) const +{ + // get index map for underlying hypotheses + //const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = + // cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().GetNonTermIndexMap(); + + //The 
Huck features. For a rule with source side: + // abXcdXef + //We first have to split into blocks: + // ab X cd X ef + //Then we extract features based in the boundary words of the neighbouring blocks + //For the block pair, we use the right word of the left block, and the left + //word of the right block. + + //Need to get blocks, and their alignment. Each block has a word range (on the + // on the source), a non-terminal flag, and a set of alignment points in the target phrase + + //We need to be able to map source word position to target word position, as + //much as possible (don't need interior of non-terminals). The alignment info + //objects just give us the mappings between *rule* positions. So if we can + //map source word position to source rule position, and target rule position + //to target word position, then we can map right through. + + size_t sourceStart = cur_hypo.GetCurrSourceRange().GetStartPos(); + size_t sourceSize = cur_hypo.GetCurrSourceRange().GetNumWordsCovered(); + + vector sourceNTSpans; + for (size_t prevHypoId = 0; prevHypoId < cur_hypo.GetPrevHypos().size(); ++prevHypoId) { + sourceNTSpans.push_back(cur_hypo.GetPrevHypo(prevHypoId)->GetCurrSourceRange()); + } + //put in source order. Is this necessary? 
+ sort(sourceNTSpans.begin(), sourceNTSpans.end()); + //cerr << "Source NTs: "; + //for (size_t i = 0; i < sourceNTSpans.size(); ++i) cerr << sourceNTSpans[i] << " "; + //cerr << endl; + + typedef pair Block;//flag indicates NT + vector sourceBlocks; + sourceBlocks.push_back(Block(cur_hypo.GetCurrSourceRange(),false)); + for (vector::const_iterator i = sourceNTSpans.begin(); + i != sourceNTSpans.end(); ++i) { + const WordsRange& prevHypoRange = *i; + Block lastBlock = sourceBlocks.back(); + sourceBlocks.pop_back(); + //split this range into before NT, NT and after NT + if (prevHypoRange.GetStartPos() > lastBlock.first.GetStartPos()) { + sourceBlocks.push_back(Block(WordsRange(lastBlock.first.GetStartPos(),prevHypoRange.GetStartPos()-1),false)); + } + sourceBlocks.push_back(Block(prevHypoRange,true)); + if (prevHypoRange.GetEndPos() < lastBlock.first.GetEndPos()) { + sourceBlocks.push_back(Block(WordsRange(prevHypoRange.GetEndPos()+1,lastBlock.first.GetEndPos()), false)); + } + } + /* + cerr << "Source Blocks: "; + for (size_t i = 0; i < sourceBlocks.size(); ++i) cerr << sourceBlocks[i].first << " " + << (sourceBlocks[i].second ? 
"NT" : "T") << " "; + cerr << endl; + */ + + //Mapping from source word to target rule position + vector sourceWordToTargetRulePos(sourceSize); + map alignMap; + alignMap.insert( + cur_hypo.GetCurrTargetPhrase().GetAlignTerm().begin(), + cur_hypo.GetCurrTargetPhrase().GetAlignTerm().end()); + alignMap.insert( + cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().begin(), + cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm().end()); + //vector alignMapTerm = cur_hypo.GetCurrTargetPhrase().GetAlignNonTerm() + size_t sourceRulePos = 0; + //cerr << "SW->RP "; + for (vector::const_iterator sourceBlockIt = sourceBlocks.begin(); + sourceBlockIt != sourceBlocks.end(); ++sourceBlockIt) { + for (size_t sourceWordPos = sourceBlockIt->first.GetStartPos(); + sourceWordPos <= sourceBlockIt->first.GetEndPos(); ++sourceWordPos) { + sourceWordToTargetRulePos[sourceWordPos - sourceStart] = alignMap[sourceRulePos]; + // cerr << sourceWordPos - sourceStart << "-" << alignMap[sourceRulePos] << " "; + if (! sourceBlockIt->second) { + //T + ++sourceRulePos; + } + } + if ( sourceBlockIt->second) { + //NT + ++sourceRulePos; + } + } + //cerr << endl; + + //Iterate through block pairs + const Sentence& sentence = + dynamic_cast(cur_hypo.GetManager().GetSource()); + //const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase(); + for (size_t i = 0; i < sourceBlocks.size()-1; ++i) { + Block& leftSourceBlock = sourceBlocks[i]; + Block& rightSourceBlock = sourceBlocks[i+1]; + size_t sourceLeftBoundaryPos = leftSourceBlock.first.GetEndPos(); + size_t sourceRightBoundaryPos = rightSourceBlock.first.GetStartPos(); + const Word& sourceLeftBoundaryWord = sentence.GetWord(sourceLeftBoundaryPos); + const Word& sourceRightBoundaryWord = sentence.GetWord(sourceRightBoundaryPos); + sourceLeftBoundaryPos -= sourceStart; + sourceRightBoundaryPos -= sourceStart; + + // Need to figure out where these map to on the target. 
+ size_t targetLeftRulePos = + sourceWordToTargetRulePos[sourceLeftBoundaryPos]; + size_t targetRightRulePos = + sourceWordToTargetRulePos[sourceRightBoundaryPos]; + + bool isMonotone = true; + if ((sourceLeftBoundaryPos < sourceRightBoundaryPos && + targetLeftRulePos > targetRightRulePos) || + ((sourceLeftBoundaryPos > sourceRightBoundaryPos && + targetLeftRulePos < targetRightRulePos))) + { + isMonotone = false; + } + stringstream buf; + buf << "h_"; //sparse reordering, Huck + if (m_type == SourceLeft || m_type == SourceCombined) { + buf << GetFactor(sourceLeftBoundaryWord,m_sourceVocab,m_sourceFactor)->GetString(); + buf << "_"; + } + if (m_type == SourceRight || m_type == SourceCombined) { + buf << GetFactor(sourceRightBoundaryWord,m_sourceVocab,m_sourceFactor)->GetString(); + buf << "_"; + } + buf << (isMonotone ? "M" : "S"); + accumulator->PlusEquals(this,buf.str(), 1); + } +// cerr << endl; +} + + +} + diff --git a/moses/FF/SparseHieroReorderingFeature.h b/moses/FF/SparseHieroReorderingFeature.h new file mode 100644 index 000000000..d631fdec1 --- /dev/null +++ b/moses/FF/SparseHieroReorderingFeature.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include + +#include + +#include "moses/Factor.h" +#include "moses/Sentence.h" + +#include "StatelessFeatureFunction.h" +#include "FFState.h" + +namespace Moses +{ + +class SparseHieroReorderingFeature : public StatelessFeatureFunction +{ +public: + enum Type { + SourceCombined, + SourceLeft, + SourceRight + }; + + SparseHieroReorderingFeature(const std::string &line); + + bool IsUseable(const FactorMask &mask) const + { return true; } + + void SetParameter(const std::string& key, const std::string& value); + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const + {} + virtual void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase 
&targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore = NULL) const + {} + + virtual void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const + {} + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const; + + +private: + + typedef boost::unordered_set Vocab; + + void AddNonTerminalPairFeatures( + const Sentence& sentence, const WordsRange& nt1, const WordsRange& nt2, + bool isMonotone, ScoreComponentCollection* accumulator) const; + + void LoadVocabulary(const std::string& filename, Vocab& vocab); + const Factor* GetFactor(const Word& word, const Vocab& vocab, FactorType factor) const; + + Type m_type; + FactorType m_sourceFactor; + FactorType m_targetFactor; + std::string m_sourceVocabFile; + std::string m_targetVocabFile; + + const Factor* m_otherFactor; + + Vocab m_sourceVocab; + Vocab m_targetVocab; + +}; + + +} + diff --git a/moses/FF/SparseHieroReorderingFeatureTest.cpp b/moses/FF/SparseHieroReorderingFeatureTest.cpp new file mode 100644 index 000000000..f05355df9 --- /dev/null +++ b/moses/FF/SparseHieroReorderingFeatureTest.cpp @@ -0,0 +1,36 @@ +/*********************************************************************** +Moses - factored phrase-based language decoder +Copyright (C) 2013- University of Edinburgh + +This library is free software; you can redistribute it and/or +modify it under the terms of the GNU Lesser General Public +License as published by the Free Software Foundation; either +version 2.1 of the License, or (at your option) any later version. + +This library is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +Lesser General Public License for more details. 
+ +You should have received a copy of the GNU Lesser General Public +License along with this library; if not, write to the Free Software +Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +***********************************************************************/ +#include + +#include + +#include "SparseHieroReorderingFeature.h" + +using namespace Moses; +using namespace std; + +BOOST_AUTO_TEST_SUITE(shrf) + +BOOST_AUTO_TEST_CASE(lexical_rule) +{ + SparseHieroReorderingFeature feature("name=shrf"); + +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/moses/FF/StatefulFeatureFunction.h b/moses/FF/StatefulFeatureFunction.h index 75b46d827..86bed04ee 100644 --- a/moses/FF/StatefulFeatureFunction.h +++ b/moses/FF/StatefulFeatureFunction.h @@ -29,12 +29,12 @@ public: * hypothesis, you should store it in an FFState object which will be passed * in as prev_state. If you don't do this, you will get in trouble. */ - virtual FFState* Evaluate( + virtual FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const = 0; - virtual FFState* EvaluateChart( + virtual FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, ScoreComponentCollection* accumulator) const = 0; diff --git a/moses/FF/StatelessFeatureFunction.h b/moses/FF/StatelessFeatureFunction.h index fde740115..94029f882 100644 --- a/moses/FF/StatelessFeatureFunction.h +++ b/moses/FF/StatelessFeatureFunction.h @@ -23,13 +23,13 @@ public: /** * This should be implemented for features that apply to phrase-based models. **/ - virtual void Evaluate(const Hypothesis& hypo, + virtual void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const = 0; /** * Same for chart-based features. 
**/ - virtual void EvaluateChart(const ChartHypothesis &hypo, + virtual void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const = 0; virtual bool IsStateless() const { diff --git a/moses/FF/SyntaxRHS.cpp b/moses/FF/SyntaxRHS.cpp new file mode 100644 index 000000000..5168b72d7 --- /dev/null +++ b/moses/FF/SyntaxRHS.cpp @@ -0,0 +1,54 @@ +#include +#include "SyntaxRHS.h" +#include "moses/ScoreComponentCollection.h" +#include "moses/TargetPhrase.h" +#include "moses/StackVec.h" + +using namespace std; + +namespace Moses +{ +SyntaxRHS::SyntaxRHS(const std::string &line) +:StatelessFeatureFunction(1, line) +{ + ReadParameters(); +} + +void SyntaxRHS::EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const +{ +} + +void SyntaxRHS::EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore) const +{ + assert(stackVec); + for (size_t i = 0; i < stackVec->size(); ++i) { + const ChartCellLabel &cell = *stackVec->at(i); + + } + + if (targetPhrase.GetNumNonTerminals()) { + vector newScores(m_numScoreComponents); + newScores[0] = - std::numeric_limits::infinity(); + scoreBreakdown.PlusEquals(this, newScores); + } + +} + +void SyntaxRHS::EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const +{} + +void SyntaxRHS::EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const +{} + +} + diff --git a/moses/FF/SyntaxRHS.h b/moses/FF/SyntaxRHS.h new file mode 100644 index 000000000..4b9214995 --- /dev/null +++ b/moses/FF/SyntaxRHS.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include "StatelessFeatureFunction.h" + +namespace Moses +{ + +class SyntaxRHS : public 
StatelessFeatureFunction +{ +public: + SyntaxRHS(const std::string &line); + + bool IsUseable(const FactorMask &mask) const { + return true; + } + + void EvaluateInIsolation(const Phrase &source + , const TargetPhrase &targetPhrase + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection &estimatedFutureScore) const; + void EvaluateWithSourceContext(const InputType &input + , const InputPath &inputPath + , const TargetPhrase &targetPhrase + , const StackVec *stackVec + , ScoreComponentCollection &scoreBreakdown + , ScoreComponentCollection *estimatedFutureScore = NULL) const; + void EvaluateWhenApplied(const Hypothesis& hypo, + ScoreComponentCollection* accumulator) const; + void EvaluateWhenApplied(const ChartHypothesis &hypo, + ScoreComponentCollection* accumulator) const; + +}; + +} + diff --git a/moses/FF/TargetBigramFeature.cpp b/moses/FF/TargetBigramFeature.cpp index 104f986e7..f1da62b7d 100644 --- a/moses/FF/TargetBigramFeature.cpp +++ b/moses/FF/TargetBigramFeature.cpp @@ -64,7 +64,7 @@ const FFState* TargetBigramFeature::EmptyHypothesisState(const InputType &/*inpu return new TargetBigramState(m_bos); } -FFState* TargetBigramFeature::Evaluate(const Hypothesis& cur_hypo, +FFState* TargetBigramFeature::EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const { diff --git a/moses/FF/TargetBigramFeature.h b/moses/FF/TargetBigramFeature.h index fe2500ad2..c63f3caa4 100644 --- a/moses/FF/TargetBigramFeature.h +++ b/moses/FF/TargetBigramFeature.h @@ -39,22 +39,22 @@ public: virtual const FFState* EmptyHypothesisState(const InputType &input) const; - virtual FFState* Evaluate(const Hypothesis& cur_hypo, const FFState* prev_state, + virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - virtual FFState* EvaluateChart( const ChartHypothesis& /* cur_hypo */, + virtual FFState* EvaluateWhenApplied( const 
ChartHypothesis& /* cur_hypo */, int /* featureID */, ScoreComponentCollection* ) const { throw std::logic_error("TargetBigramFeature not valid in chart decoder"); } - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/TargetNgramFeature.cpp b/moses/FF/TargetNgramFeature.cpp index 9ea1ccadf..a43410990 100644 --- a/moses/FF/TargetNgramFeature.cpp +++ b/moses/FF/TargetNgramFeature.cpp @@ -82,6 +82,7 @@ void TargetNgramFeature::Load() m_vocab.insert(EOS_); while (getline(inFile, line)) { m_vocab.insert(line); + cerr << "ADD TO VOCAB: '" << line << "'" << endl; } inFile.close(); @@ -94,7 +95,7 @@ const FFState* TargetNgramFeature::EmptyHypothesisState(const InputType &/*input return new TargetNgramState(bos); } -FFState* TargetNgramFeature::Evaluate(const Hypothesis& cur_hypo, +FFState* TargetNgramFeature::EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const { @@ -119,7 +120,9 @@ FFState* TargetNgramFeature::Evaluate(const Hypothesis& cur_hypo, // const string& curr_w = targetPhrase.GetWord(i).GetFactor(m_factorType)->GetString(); const StringPiece curr_w = targetPhrase.GetWord(i).GetString(m_factorType); + //cerr << "CHECK WORD '" << curr_w << "'" << endl; if (m_vocab.size() && (FindStringPiece(m_vocab, curr_w) == m_vocab.end())) continue; // skip ngrams + //cerr << "ALLOWED WORD '" << curr_w << "'" << endl; if (n > 1) { // can we build an ngram at this position? 
(" this" --> cannot build 3gram at this position) @@ -154,6 +157,7 @@ FFState* TargetNgramFeature::Evaluate(const Hypothesis& cur_hypo, if (!skip) { curr_ngram << curr_w; + //cerr << "SCORE '" << curr_ngram.str() << "'" << endl; accumulator->PlusEquals(this,curr_ngram.str(),1); } curr_ngram.str(""); @@ -203,7 +207,7 @@ void TargetNgramFeature::appendNgram(const Word& word, bool& skip, stringstream } } -FFState* TargetNgramFeature::EvaluateChart(const ChartHypothesis& cur_hypo, int featureId, ScoreComponentCollection* accumulator) const +FFState* TargetNgramFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureId, ScoreComponentCollection* accumulator) const { vector contextFactor; contextFactor.reserve(m_n); diff --git a/moses/FF/TargetNgramFeature.h b/moses/FF/TargetNgramFeature.h index 8e91a08b2..e87252670 100644 --- a/moses/FF/TargetNgramFeature.h +++ b/moses/FF/TargetNgramFeature.h @@ -186,20 +186,20 @@ public: virtual const FFState* EmptyHypothesisState(const InputType &input) const; - virtual FFState* Evaluate(const Hypothesis& cur_hypo, const FFState* prev_state, + virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; - virtual FFState* EvaluateChart(const ChartHypothesis& cur_hypo, int featureId, + virtual FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureId, ScoreComponentCollection* accumulator) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git 
a/moses/FF/TargetWordInsertionFeature.cpp b/moses/FF/TargetWordInsertionFeature.cpp index 7bb1ae6e9..c8db6bfe3 100644 --- a/moses/FF/TargetWordInsertionFeature.cpp +++ b/moses/FF/TargetWordInsertionFeature.cpp @@ -53,7 +53,7 @@ void TargetWordInsertionFeature::Load() m_unrestricted = false; } -void TargetWordInsertionFeature::Evaluate(const Phrase &source +void TargetWordInsertionFeature::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/TargetWordInsertionFeature.h b/moses/FF/TargetWordInsertionFeature.h index eedde61b2..06fa25400 100644 --- a/moses/FF/TargetWordInsertionFeature.h +++ b/moses/FF/TargetWordInsertionFeature.h @@ -28,21 +28,21 @@ public: void Load(); - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} diff --git a/moses/FF/TreeStructureFeature.cpp b/moses/FF/TreeStructureFeature.cpp index 25490a470..c0505edd6 100644 --- a/moses/FF/TreeStructureFeature.cpp +++ b/moses/FF/TreeStructureFeature.cpp @@ -4,7 +4,6 @@ #include "moses/Hypothesis.h" #include "moses/ChartHypothesis.h" #include "moses/TargetPhrase.h" -#include #include #include 
"moses/PP/TreeStructurePhraseProperty.h" @@ -267,14 +266,13 @@ void TreeStructureFeature::AddNTLabels(TreePointer root) const { } } -FFState* TreeStructureFeature::EvaluateChart(const ChartHypothesis& cur_hypo +FFState* TreeStructureFeature::EvaluateWhenApplied(const ChartHypothesis& cur_hypo , int featureID /* used to index the state in the previous hypotheses */ , ScoreComponentCollection* accumulator) const { - boost::shared_ptr property; - if (cur_hypo.GetCurrTargetPhrase().GetProperty("Tree", property)) { - const std::string &tree = property->GetValueString(); - TreePointer mytree (new InternalTree(tree)); + if (const PhraseProperty *property = cur_hypo.GetCurrTargetPhrase().GetProperty("Tree")) { + const std::string *tree = property->GetValueString(); + TreePointer mytree (new InternalTree(*tree)); if (m_labelset) { AddNTLabels(mytree); diff --git a/moses/FF/TreeStructureFeature.h b/moses/FF/TreeStructureFeature.h index 0fbf0f9ea..a81d604bb 100644 --- a/moses/FF/TreeStructureFeature.h +++ b/moses/FF/TreeStructureFeature.h @@ -152,21 +152,21 @@ public: return true; } - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const {}; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {}; - FFState* Evaluate( + FFState* EvaluateWhenApplied( const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const {UTIL_THROW(util::Exception, "Not implemented");}; - FFState* EvaluateChart( + FFState* EvaluateWhenApplied( const ChartHypothesis& /* cur_hypo */, int /* featureID - used to index the state in the previous hypotheses */, 
ScoreComponentCollection* accumulator) const; diff --git a/moses/FF/UnknownWordPenaltyProducer.h b/moses/FF/UnknownWordPenaltyProducer.h index 3b48f4380..8850641e5 100644 --- a/moses/FF/UnknownWordPenaltyProducer.h +++ b/moses/FF/UnknownWordPenaltyProducer.h @@ -31,20 +31,20 @@ public: } std::vector DefaultWeights() const; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/WordPenaltyProducer.cpp b/moses/FF/WordPenaltyProducer.cpp index 6dea01b72..1e191d040 100644 --- a/moses/FF/WordPenaltyProducer.cpp +++ b/moses/FF/WordPenaltyProducer.cpp @@ -17,7 +17,7 @@ WordPenaltyProducer::WordPenaltyProducer(const std::string &line) s_instance = this; } -void WordPenaltyProducer::Evaluate(const Phrase &source +void WordPenaltyProducer::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/FF/WordPenaltyProducer.h b/moses/FF/WordPenaltyProducer.h index ffd921677..e62877307 100644 --- a/moses/FF/WordPenaltyProducer.h +++ b/moses/FF/WordPenaltyProducer.h @@ -27,17 +27,17 @@ public: return true; } - virtual void Evaluate(const Phrase &source + virtual void 
EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const {} - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const {} - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec diff --git a/moses/FF/WordTranslationFeature.cpp b/moses/FF/WordTranslationFeature.cpp index 554107c32..7a98ad4c8 100644 --- a/moses/FF/WordTranslationFeature.cpp +++ b/moses/FF/WordTranslationFeature.cpp @@ -137,7 +137,7 @@ void WordTranslationFeature::Load() } } -void WordTranslationFeature::Evaluate +void WordTranslationFeature::EvaluateWhenApplied (const Hypothesis& hypo, ScoreComponentCollection* accumulator) const { @@ -349,7 +349,7 @@ void WordTranslationFeature::Evaluate } } -void WordTranslationFeature::EvaluateChart( +void WordTranslationFeature::EvaluateWhenApplied( const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const { diff --git a/moses/FF/WordTranslationFeature.h b/moses/FF/WordTranslationFeature.h index 072ba1d6a..c213d8eb3 100644 --- a/moses/FF/WordTranslationFeature.h +++ b/moses/FF/WordTranslationFeature.h @@ -48,19 +48,19 @@ public: return new DummyState(); } - void Evaluate(const Hypothesis& hypo, + void EvaluateWhenApplied(const Hypothesis& hypo, ScoreComponentCollection* accumulator) const; - void EvaluateChart(const ChartHypothesis &hypo, + void EvaluateWhenApplied(const ChartHypothesis &hypo, ScoreComponentCollection* accumulator) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const 
TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore = NULL) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/Hypothesis.cpp b/moses/Hypothesis.cpp index 400fd0e0f..61e7c3f71 100644 --- a/moses/Hypothesis.cpp +++ b/moses/Hypothesis.cpp @@ -211,7 +211,7 @@ void Hypothesis::EvaluateWith(const StatefulFeatureFunction &sfff, { const StaticData &staticData = StaticData::Instance(); if (! staticData.IsFeatureFunctionIgnored( sfff )) { - m_ffStates[state_idx] = sfff.Evaluate( + m_ffStates[state_idx] = sfff.EvaluateWhenApplied( *this, m_prevHypo ? m_prevHypo->m_ffStates[state_idx] : NULL, &m_scoreBreakdown); @@ -222,7 +222,7 @@ void Hypothesis::EvaluateWith(const StatelessFeatureFunction& slff) { const StaticData &staticData = StaticData::Instance(); if (! staticData.IsFeatureFunctionIgnored( slff )) { - slff.Evaluate(*this, &m_scoreBreakdown); + slff.EvaluateWhenApplied(*this, &m_scoreBreakdown); } } @@ -254,7 +254,7 @@ void Hypothesis::Evaluate(const SquareMatrix &futureScore) const StatefulFeatureFunction &ff = *ffs[i]; const StaticData &staticData = StaticData::Instance(); if (! staticData.IsFeatureFunctionIgnored(ff)) { - m_ffStates[i] = ff.Evaluate(*this, + m_ffStates[i] = ff.EvaluateWhenApplied(*this, m_prevHypo ? 
m_prevHypo->m_ffStates[i] : NULL, &m_scoreBreakdown); } diff --git a/moses/Incremental.cpp b/moses/Incremental.cpp index 4e593df7e..c8a48d425 100644 --- a/moses/Incremental.cpp +++ b/moses/Incremental.cpp @@ -327,7 +327,7 @@ void PhraseAndFeatures(const search::Applied final, Phrase &phrase, ScoreCompone const LanguageModel &model = LanguageModel::GetFirstLM(); model.CalcScore(phrase, full, ignored_ngram, ignored_oov); - // CalcScore transforms, but EvaluateChart doesn't. + // CalcScore transforms, but EvaluateWhenApplied doesn't. features.Assign(&model, full); } diff --git a/moses/InputPath.cpp b/moses/InputPath.cpp index f00f1a7a4..523b03d53 100644 --- a/moses/InputPath.cpp +++ b/moses/InputPath.cpp @@ -85,7 +85,7 @@ size_t InputPath::GetTotalRuleSize() const size_t ret = 0; std::map >::const_iterator iter; for (iter = m_targetPhrases.begin(); iter != m_targetPhrases.end(); ++iter) { - const PhraseDictionary *pt = iter->first; + // const PhraseDictionary *pt = iter->first; const TargetPhraseCollection *tpColl = iter->second.first; if (tpColl) { diff --git a/moses/Jamfile b/moses/Jamfile index cc65f56ea..60ab877b0 100644 --- a/moses/Jamfile +++ b/moses/Jamfile @@ -10,7 +10,14 @@ if $(with-dlib) { dlib = ; } -alias headers : ../util//kenutil : : : $(max-factors) $(dlib) ; +with-lbllm = [ option.get "with-lbllm" ] ; +if $(with-lbllm) { + lbllm2 = -std=c++0x LM_LBL $(with-lbllm)/src $(with-lbllm)/3rdparty/eigen-3 ; +} else { + lbllm2 = ; +} + +alias headers : ../util//kenutil : : : $(max-factors) $(dlib) $(lbllm2) ; alias ThreadPool : ThreadPool.cpp ; alias Util : Util.cpp Timer.cpp ; @@ -69,10 +76,11 @@ lib moses : : #exceptions ThreadPool.cpp SyntacticLanguageModel.cpp - *Test.cpp Mock*.cpp + *Test.cpp Mock*.cpp FF/*Test.cpp FF/Factory.cpp ] -headers FF_Factory.o LM//LM TranslationModel/CompactPT//CompactPT synlm ThreadPool +headers FF_Factory.o LM//LM TranslationModel/CompactPT//CompactPT TranslationModel/ProbingPT//ProbingPT synlm ThreadPool + ..//search 
../util/double-conversion//double-conversion ..//z ../OnDiskPt//OnDiskPt $(TOP)//boost_iostreams mmlib : @@ -84,5 +92,5 @@ alias headers-to-install : [ glob-tree *.h ] ; import testing ; -unit-test moses_test : [ glob *Test.cpp Mock*.cpp ] moses headers ..//z ../OnDiskPt//OnDiskPt ..//boost_unit_test_framework ; +unit-test moses_test : [ glob *Test.cpp Mock*.cpp FF/*Test.cpp ] moses headers ..//z ../OnDiskPt//OnDiskPt ..//boost_unit_test_framework ; diff --git a/moses/LM/Base.cpp b/moses/LM/Base.cpp index f59b5e31b..db71119d5 100644 --- a/moses/LM/Base.cpp +++ b/moses/LM/Base.cpp @@ -69,7 +69,7 @@ void LanguageModel::ReportHistoryOrder(std::ostream &out,const Phrase &phrase) c // out << "ReportHistoryOrder not implemented"; } -void LanguageModel::Evaluate(const Phrase &source +void LanguageModel::EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/LM/Base.h b/moses/LM/Base.h index abae5de24..2be19e5bd 100644 --- a/moses/LM/Base.h +++ b/moses/LM/Base.h @@ -87,11 +87,11 @@ public: virtual void IncrementalCallback(Incremental::Manager &manager) const; virtual void ReportHistoryOrder(std::ostream &out,const Phrase &phrase) const; - virtual void Evaluate(const Phrase &source + virtual void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const; - void Evaluate(const InputType &input + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec diff --git a/moses/LM/DALMWrapper.cpp b/moses/LM/DALMWrapper.cpp index 420efd9e8..68b3050de 100644 --- a/moses/LM/DALMWrapper.cpp +++ b/moses/LM/DALMWrapper.cpp @@ -288,7 +288,7 @@ void LanguageModelDALM::CalcScore(const Phrase &phrase, float &fullScore, float ngramScore = 
TransformLMScore(ngramScore); } -FFState *LanguageModelDALM::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const{ +FFState *LanguageModelDALM::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const{ // In this function, we only compute the LM scores of n-grams that overlap a // phrase boundary. Phrase-internal scores are taken directly from the // translation option. @@ -339,7 +339,7 @@ FFState *LanguageModelDALM::Evaluate(const Hypothesis &hypo, const FFState *ps, return dalm_state; } -FFState *LanguageModelDALM::EvaluateChart(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const{ +FFState *LanguageModelDALM::EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const{ // initialize language model context state DALMChartState *newState = new DALMChartState(); DALM::State &state = newState->GetRightContext(); diff --git a/moses/LM/DALMWrapper.h b/moses/LM/DALMWrapper.h index c791eeea6..ad53819c0 100644 --- a/moses/LM/DALMWrapper.h +++ b/moses/LM/DALMWrapper.h @@ -34,9 +34,9 @@ public: virtual void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const; - virtual FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const; + virtual FFState *EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const; - virtual FFState *EvaluateChart(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const; + virtual FFState *EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *out) const; virtual bool IsUseable(const FactorMask &mask) const; diff --git a/moses/LM/Implementation.cpp b/moses/LM/Implementation.cpp index ef09fbc77..bd5bd1834 100644 --- a/moses/LM/Implementation.cpp +++ b/moses/LM/Implementation.cpp @@ -134,7 +134,7 @@ void 
LanguageModelImplementation::CalcScore(const Phrase &phrase, float &fullSco } } -FFState *LanguageModelImplementation::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const +FFState *LanguageModelImplementation::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const { // In this function, we only compute the LM scores of n-grams that overlap a // phrase boundary. Phrase-internal scores are taken directly from the @@ -222,7 +222,7 @@ FFState *LanguageModelImplementation::Evaluate(const Hypothesis &hypo, const FFS return res; } -FFState* LanguageModelImplementation::EvaluateChart(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection* out) const +FFState* LanguageModelImplementation::EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection* out) const { LanguageModelChartState *ret = new LanguageModelChartState(hypo, featureID, GetNGramOrder()); // data structure for factored context phrase (history and predicted word) diff --git a/moses/LM/Implementation.h b/moses/LM/Implementation.h index a39f5e42b..5eb8fb209 100644 --- a/moses/LM/Implementation.h +++ b/moses/LM/Implementation.h @@ -89,9 +89,9 @@ public: void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const; - FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const; + FFState *EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const; - FFState* EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator) const; + FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator) const; void updateChartScore(float *prefixScore, float *finalScore, float score, size_t wordPos) const; diff --git a/moses/LM/Jamfile b/moses/LM/Jamfile index 4f964ddd8..87b0c1b36 100644 --- 
a/moses/LM/Jamfile +++ b/moses/LM/Jamfile @@ -90,6 +90,18 @@ if $(with-nplm) { lmmacros += LM_NEURAL ; } +#LBLLM +local with-lbllm = [ option.get "with-lbllm" ] ; +if $(with-lbllm) { + lib lbl : : $(with-lbllm)/lib $(with-lbllm)/lib64 ; + obj LBLLM.o : oxlm/LBLLM.cpp lbl ..//headers : $(with-lbllm)/src $(with-lbllm)/3rdparty/eigen-3 ; + obj Mapper.o : oxlm/Mapper.cpp lbl ..//headers : $(with-lbllm)/src $(with-lbllm)/3rdparty/eigen-3 ; + alias lbllm : LBLLM.o Mapper.o lbl : : : -std=c++0x LM_LBL ; + dependencies += lbllm ; + lmmacros += LM_LBL ; +} + + #DALM local with-dalm = [ option.get "with-dalm" ] ; if $(with-dalm) { diff --git a/moses/LM/Ken.cpp b/moses/LM/Ken.cpp index 2dfb58c23..e69746084 100644 --- a/moses/LM/Ken.cpp +++ b/moses/LM/Ken.cpp @@ -79,7 +79,7 @@ struct KenLMState : public FFState { // // FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const; // -// FFState *EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const; +// FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const; // // void IncrementalCallback(Incremental::Manager &manager) const { // manager.LMCallback(*m_ngram, m_lmIdLookup); @@ -229,7 +229,7 @@ template void LanguageModelKen::CalcScore(const Phrase &phr fullScore = TransformLMScore(fullScore); } -template FFState *LanguageModelKen::Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const +template FFState *LanguageModelKen::EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const { const lm::ngram::State &in_state = static_cast(*ps).state; @@ -307,7 +307,7 @@ private: lm::ngram::ChartState m_state; }; -template FFState *LanguageModelKen::EvaluateChart(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection *accumulator) const +template FFState *LanguageModelKen::EvaluateWhenApplied(const 
ChartHypothesis& hypo, int featureID, ScoreComponentCollection *accumulator) const { LanguageModelChartStateKenLM *newState = new LanguageModelChartStateKenLM(); lm::ngram::RuleScore ruleScore(*m_ngram, newState->GetChartState()); diff --git a/moses/LM/Ken.h b/moses/LM/Ken.h index e5950f591..2f473b697 100644 --- a/moses/LM/Ken.h +++ b/moses/LM/Ken.h @@ -55,9 +55,9 @@ public: virtual void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const; - virtual FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const; + virtual FFState *EvaluateWhenApplied(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out) const; - virtual FFState *EvaluateChart(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const; + virtual FFState *EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection *accumulator) const; virtual void IncrementalCallback(Incremental::Manager &manager) const; virtual void ReportHistoryOrder(std::ostream &out,const Phrase &phrase) const; diff --git a/moses/LM/LDHT.cpp b/moses/LM/LDHT.cpp index 61226208c..1d0331df5 100644 --- a/moses/LM/LDHT.cpp +++ b/moses/LM/LDHT.cpp @@ -97,7 +97,7 @@ public: FFState* Evaluate(const Hypothesis& hypo, const FFState* input_state, ScoreComponentCollection* score_output) const; - FFState* EvaluateChart(const ChartHypothesis& hypo, + FFState* EvaluateWhenApplied(const ChartHypothesis& hypo, int featureID, ScoreComponentCollection* accumulator) const; @@ -392,7 +392,7 @@ FFState* LanguageModelLDHT::Evaluate( return state; } -FFState* LanguageModelLDHT::EvaluateChart( +FFState* LanguageModelLDHT::EvaluateWhenApplied( const ChartHypothesis& hypo, int featureID, ScoreComponentCollection* accumulator) const diff --git a/moses/LM/SingleFactor.cpp b/moses/LM/SingleFactor.cpp index 74b8f4fe5..1efb13f16 100644 --- a/moses/LM/SingleFactor.cpp +++ b/moses/LM/SingleFactor.cpp @@ 
-87,6 +87,17 @@ void LanguageModelSingleFactor::SetParameter(const std::string& key, const std:: } } +std::string LanguageModelSingleFactor::DebugContextFactor(const std::vector &contextFactor) const +{ + std::string ret; + for (size_t i = 0; i < contextFactor.size(); ++i) { + const Word &word = *contextFactor[i]; + ret += word.ToString(); + } + + return ret; +} + } diff --git a/moses/LM/SingleFactor.h b/moses/LM/SingleFactor.h index eeb5cdbef..fd1d893e6 100644 --- a/moses/LM/SingleFactor.h +++ b/moses/LM/SingleFactor.h @@ -67,6 +67,8 @@ public: virtual LMResult GetValueForgotState(const std::vector &contextFactor, FFState &outState) const; virtual LMResult GetValue(const std::vector &contextFactor, State* finalState = NULL) const = 0; + + std::string DebugContextFactor(const std::vector &contextFactor) const; }; diff --git a/moses/LM/oxlm/LBLLM.cpp b/moses/LM/oxlm/LBLLM.cpp new file mode 100644 index 000000000..20f1a2149 --- /dev/null +++ b/moses/LM/oxlm/LBLLM.cpp @@ -0,0 +1,12 @@ + +#include "LBLLM.h" + +using namespace std; + +namespace Moses +{ + +} + + + diff --git a/moses/LM/oxlm/LBLLM.h b/moses/LM/oxlm/LBLLM.h new file mode 100644 index 000000000..07ed9a8d3 --- /dev/null +++ b/moses/LM/oxlm/LBLLM.h @@ -0,0 +1,122 @@ +// $Id$ +#pragma once + +#include +#include +#include "moses/LM/SingleFactor.h" +#include "moses/FactorCollection.h" + +// lbl stuff +#include "corpus/corpus.h" +#include "lbl/lbl_features.h" +#include "lbl/model.h" +#include "lbl/process_identifier.h" +#include "lbl/query_cache.h" + +#include "Mapper.h" + +namespace Moses +{ + + +template +class LBLLM : public LanguageModelSingleFactor +{ +protected: + +public: + LBLLM(const std::string &line) + :LanguageModelSingleFactor(line) + { + ReadParameters(); + + FactorCollection &factorCollection = FactorCollection::Instance(); + + // needed by parent language model classes. Why didn't they set these themselves? 
+ m_sentenceStart = factorCollection.AddFactor(Output, m_factorType, BOS_); + m_sentenceStartWord[m_factorType] = m_sentenceStart; + + m_sentenceEnd = factorCollection.AddFactor(Output, m_factorType, EOS_); + m_sentenceEndWord[m_factorType] = m_sentenceEnd; + } + + ~LBLLM() + {} + + void Load() + { + model.load(m_filePath); + + config = model.getConfig(); + int context_width = config->ngram_order - 1; + // For each state, we store at most context_width word ids to the left and + // to the right and a kSTAR separator. The last bit represents the actual + // size of the state. + //int max_state_size = (2 * context_width + 1) * sizeof(int) + 1; + //FeatureFunction::SetStateSize(max_state_size); + + dict = model.getDict(); + mapper = boost::make_shared(dict); + //stateConverter = boost::make_shared(max_state_size - 1); + //ruleConverter = boost::make_shared(mapper, stateConverter); + + kSTART = dict.Convert(""); + kSTOP = dict.Convert(""); + kUNKNOWN = dict.Convert(""); + } + + + virtual LMResult GetValue(const std::vector &contextFactor, State* finalState = 0) const + { + std::vector context; + int word; + mapper->convert(contextFactor, context, word); + + size_t context_width = m_nGramOrder - 1; + + if (!context.empty() && context.back() == kSTART) { + context.resize(context_width, kSTART); + } else { + context.resize(context_width, kUNKNOWN); + } + + + double score; + score = model.predict(word, context); + + /* + std::string str = DebugContextFactor(contextFactor); + std::cerr << "contextFactor=" << str << " " << score << std::endl; + */ + + LMResult ret; + ret.score = score; + ret.unknown = (word == kUNKNOWN); + + // calc state from hash of last n-1 words + size_t seed = 0; + boost::hash_combine(seed, word); + for (size_t i = 0; i < context.size() && i < context_width - 1; ++i) { + int id = context[i]; + boost::hash_combine(seed, id); + } + + (*finalState) = (State*) seed; + return ret; + } + +protected: + oxlm::Dict dict; + boost::shared_ptr config; + Model 
model; + + int kSTART; + int kSTOP; + int kUNKNOWN; + + boost::shared_ptr mapper; + +}; + + +} diff --git a/moses/LM/oxlm/Mapper.cpp b/moses/LM/oxlm/Mapper.cpp new file mode 100644 index 000000000..f1363ccf0 --- /dev/null +++ b/moses/LM/oxlm/Mapper.cpp @@ -0,0 +1,67 @@ +#include "Mapper.h" +#include "moses/FactorCollection.h" + +using namespace std; + +namespace Moses +{ +OXLMMapper::OXLMMapper(const oxlm::Dict& dict) : dict(dict) +{ + for (int i = 0; i < dict.size(); ++i) { + const string &str = dict.Convert(i); + FactorCollection &fc = FactorCollection::Instance(); + const Moses::Factor *factor = fc.AddFactor(str, false); + moses2lbl[factor] = i; + + //add(i, TD::Convert()); + } + + kUNKNOWN = this->dict.Convert(""); +} + +int OXLMMapper::convert(const Moses::Factor *factor) const +{ + Coll::const_iterator iter; + iter = moses2lbl.find(factor); + if (iter == moses2lbl.end()) { + return kUNKNOWN; + } + else { + int ret = iter->second; + return ret; + } +} + +std::vector OXLMMapper::convert(const Phrase &phrase) const +{ + size_t size = phrase.GetSize(); + vector ret(size); + + for (size_t i = 0; i < size; ++i) { + const Moses::Factor *factor = phrase.GetFactor(i, 0); + int id = convert(factor); + ret[i] = id; + } + return ret; +} + +void OXLMMapper::convert(const std::vector &contextFactor, std::vector &ids, int &word) const +{ + size_t size = contextFactor.size(); + + ids.resize(size - 1); + + for (size_t i = 0; i < size - 1; ++i) { + const Moses::Factor *factor = contextFactor[i]->GetFactor(0); + int id = convert(factor); + ids[i] = id; + } + std::reverse(ids.begin(), ids.end()); + + const Moses::Factor *factor = contextFactor.back()->GetFactor(0); + word = convert(factor); + +} + +} // namespace + diff --git a/moses/LM/oxlm/Mapper.h b/moses/LM/oxlm/Mapper.h new file mode 100644 index 000000000..79cbf7b5f --- /dev/null +++ b/moses/LM/oxlm/Mapper.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include "corpus/corpus.h" +#include "moses/Factor.h" +#include 
"moses/Phrase.h" + +namespace Moses +{ +class OXLMMapper +{ +public: + OXLMMapper(const oxlm::Dict& dict); + + int convert(const Moses::Factor *factor) const; + std::vector convert(const Phrase &phrase) const; + void convert(const std::vector &contextFactor, std::vector &ids, int &word) const; + +private: + void add(int lbl_id, int cdec_id); + + oxlm::Dict dict; + typedef std::map Coll; + Coll moses2lbl; + int kUNKNOWN; + +}; + +/** + * Wraps the feature values computed from the LBL language model. + */ +struct LBLFeatures { + LBLFeatures() : LMScore(0), OOVScore(0) {} + LBLFeatures(double lm_score, double oov_score) + : LMScore(lm_score), OOVScore(oov_score) {} + LBLFeatures& operator+=(const LBLFeatures& other) { + LMScore += other.LMScore; + OOVScore += other.OOVScore; + return *this; + } + + double LMScore; + double OOVScore; +}; + +} diff --git a/moses/PDTAimp.h b/moses/PDTAimp.h index 999fbb1e0..2a7943ce2 100644 --- a/moses/PDTAimp.h +++ b/moses/PDTAimp.h @@ -233,7 +233,7 @@ public: //InputFileStream in(filePath); //m_dict->Create(in,filePath); } - TRACE_ERR( "reading bin ttable\n"); + VERBOSE(1,"reading bin ttable\n"); // m_dict->Read(filePath); bool res=m_dict->Read(filePath); if (!res) { diff --git a/moses/PP/CountsPhraseProperty.cpp b/moses/PP/CountsPhraseProperty.cpp new file mode 100644 index 000000000..b64366733 --- /dev/null +++ b/moses/PP/CountsPhraseProperty.cpp @@ -0,0 +1,38 @@ +#include "moses/PP/CountsPhraseProperty.h" +#include +#include + +namespace Moses +{ + +void CountsPhraseProperty::ProcessValue(const std::string &value) +{ + std::istringstream tokenizer(value); + + if (! (tokenizer >> m_targetMarginal)) { // first token: countE + UTIL_THROW2("CountsPhraseProperty: Not able to read target marginal. Flawed property?"); + } + assert( m_targetMarginal > 0 ); + + if (! (tokenizer >> m_sourceMarginal)) { // first token: countF + UTIL_THROW2("CountsPhraseProperty: Not able to read source marginal. 
Flawed property?"); + } + assert( m_sourceMarginal > 0 ); + + if (! (tokenizer >> m_jointCount)) { // first token: countEF + UTIL_THROW2("CountsPhraseProperty: Not able to read joint count. Flawed property?"); + } + assert( m_jointCount > 0 ); +}; + +std::ostream& operator<<(std::ostream &out, const CountsPhraseProperty &obj) +{ + out << "Count property=" + << obj.GetTargetMarginal() << " " + << obj.GetSourceMarginal() << " " + << obj.GetJointCount(); + return out; +} + +} // namespace Moses + diff --git a/moses/PP/CountsPhraseProperty.h b/moses/PP/CountsPhraseProperty.h new file mode 100644 index 000000000..4f6fbcfa8 --- /dev/null +++ b/moses/PP/CountsPhraseProperty.h @@ -0,0 +1,62 @@ + +#pragma once + +#include "moses/PP/PhraseProperty.h" +#include "util/exception.hh" +#include +#include + +namespace Moses +{ + +// A simple phrase property class to access the three phrase count values. +// +// The counts are usually not needed during decoding and are not loaded +// from the phrase table. This is just a workaround that can make them +// available to features which have a use for them. +// +// If you need access to the counts, copy the two marginal counts and the +// joint count into an additional information property with key "Counts", +// e.g. using awk: +// +// $ zcat phrase-table.gz | awk -F' \|\|\| ' '{printf("%s {{Counts %s}}\n",$0,$5);}' | gzip -c > phrase-table.withCountsPP.gz +// +// CountsPhraseProperty reads them from the phrase table and provides +// methods GetSourceMarginal(), GetTargetMarginal(), GetJointCount(). 
+ + +class CountsPhraseProperty : public PhraseProperty +{ + friend std::ostream& operator<<(std::ostream &, const CountsPhraseProperty &); + +public: + + CountsPhraseProperty() {}; + + virtual void ProcessValue(const std::string &value); + + size_t GetSourceMarginal() const { + return m_sourceMarginal; + } + + size_t GetTargetMarginal() const { + return m_targetMarginal; + } + + float GetJointCount() const { + return m_jointCount; + } + + virtual const std::string *GetValueString() const { + UTIL_THROW2("CountsPhraseProperty: value string not available in this phrase property"); + return NULL; + }; + +protected: + + float m_sourceMarginal, m_targetMarginal, m_jointCount; + +}; + +} // namespace Moses + diff --git a/moses/PP/Factory.cpp b/moses/PP/Factory.cpp index 61e96a7f2..4e9bfbf0e 100644 --- a/moses/PP/Factory.cpp +++ b/moses/PP/Factory.cpp @@ -4,7 +4,11 @@ #include #include +#include "moses/PP/CountsPhraseProperty.h" +#include "moses/PP/SourceLabelsPhraseProperty.h" #include "moses/PP/TreeStructurePhraseProperty.h" +#include "moses/PP/SpanLengthPhraseProperty.h" +#include "moses/PP/NonTermContextProperty.h" namespace Moses { @@ -34,8 +38,8 @@ template class DefaultPhrasePropertyCreator : public PhrasePropertyCre { public: boost::shared_ptr CreateProperty(const std::string &value) { - P* property = new P(value); - property->ProcessValue(); + P* property = new P(); + property->ProcessValue(value); return Create(property); } }; @@ -50,8 +54,11 @@ PhrasePropertyFactory::PhrasePropertyFactory() // Properties with different key than class. 
#define MOSES_PNAME2(name, type) Add(name, new DefaultPhrasePropertyCreator< type >()); + MOSES_PNAME2("Counts", CountsPhraseProperty); + MOSES_PNAME2("SourceLabels", SourceLabelsPhraseProperty); MOSES_PNAME2("Tree",TreeStructurePhraseProperty); - + MOSES_PNAME2("SpanLength", SpanLengthPhraseProperty); + MOSES_PNAME2("NonTermContext", NonTermContextProperty); } PhrasePropertyFactory::~PhrasePropertyFactory() diff --git a/moses/PP/NonTermContextProperty.cpp b/moses/PP/NonTermContextProperty.cpp new file mode 100644 index 000000000..df5e88d8e --- /dev/null +++ b/moses/PP/NonTermContextProperty.cpp @@ -0,0 +1,137 @@ +#include "moses/PP/NonTermContextProperty.h" +#include +#include +#include "moses/Util.h" +#include "moses/FactorCollection.h" + +using namespace std; + +namespace Moses +{ +NonTermContextProperty::NonTermContextProperty() +{ +} + +NonTermContextProperty::~NonTermContextProperty() +{ + //RemoveAllInColl(m_probStores); +} + +void NonTermContextProperty::ProcessValue(const std::string &value) +{ + vector toks; + Tokenize(toks, value); + + FactorCollection &fc = FactorCollection::Instance(); + + size_t numNT = Scan(toks[0]); + m_probStores.resize(numNT); + + size_t ind = 1; + while (ind < toks.size()) { + vector factors; + + for (size_t nt = 0; nt < numNT; ++nt) { + size_t ntInd = Scan(toks[ind]); + assert(nt == ntInd); + ++ind; + + for (size_t contextInd = 0; contextInd < 4; ++contextInd) { + //cerr << "toks[" << ind << "]=" << toks[ind] << endl; + const Factor *factor = fc.AddFactor(toks[ind], false); + factors.push_back(factor); + ++ind; + } + } + + // done with the context. 
Just get the count and put it all into data structures + // cerr << "count=" << toks[ind] << endl; + float count = Scan(toks[ind]); + ++ind; + + for (size_t i = 0; i < factors.size(); ++i) { + size_t ntInd = i / 4; + size_t contextInd = i % 4; + const Factor *factor = factors[i]; + AddToMap(ntInd, contextInd, factor, count); + } + } +} + +void NonTermContextProperty::AddToMap(size_t ntIndex, size_t index, const Factor *factor, float count) +{ + if (ntIndex <= m_probStores.size()) { + m_probStores.resize(ntIndex + 1); + } + + ProbStore &probStore = m_probStores[ntIndex]; + probStore.AddToMap(index, factor, count); +} + +float NonTermContextProperty::GetProb(size_t ntInd, + size_t contextInd, + const Factor *factor, + float smoothConstant) const +{ + UTIL_THROW_IF2(ntInd >= m_probStores.size(), "Invalid nt index=" << ntInd); + const ProbStore &probStore = m_probStores[ntInd]; + float ret = probStore.GetProb(contextInd, factor, smoothConstant); + return ret; +} + +////////////////////////////////////////// + +void NonTermContextProperty::ProbStore::AddToMap(size_t index, const Factor *factor, float count) +{ + Map &map = m_vec[index]; + + Map::iterator iter = map.find(factor); + if (iter == map.end()) { + map[factor] = count; + } + else { + float &currCount = iter->second; + currCount += count; + } + + m_totalCount += count; +} + + +float NonTermContextProperty::ProbStore::GetProb(size_t contextInd, + const Factor *factor, + float smoothConstant) const +{ + float count = GetCount(contextInd, factor, smoothConstant); + float total = GetTotalCount(contextInd, smoothConstant); + float ret = count / total; + return ret; +} + +float NonTermContextProperty::ProbStore::GetCount(size_t contextInd, + const Factor *factor, + float smoothConstant) const +{ + const Map &map = m_vec[contextInd]; + + float count = smoothConstant; + Map::const_iterator iter = map.find(factor); + if (iter == map.end()) { + // nothing + } + else { + count += iter->second; + } + + return count; +} + 
+float NonTermContextProperty::ProbStore::GetTotalCount(size_t contextInd, float smoothConstant) const +{ + const Map &map = m_vec[contextInd]; + return m_totalCount + smoothConstant * map.size(); +} + + +} // namespace Moses + diff --git a/moses/PP/NonTermContextProperty.h b/moses/PP/NonTermContextProperty.h new file mode 100644 index 000000000..56db9cb32 --- /dev/null +++ b/moses/PP/NonTermContextProperty.h @@ -0,0 +1,73 @@ + +#pragma once + +#include "moses/PP/PhraseProperty.h" +#include "util/exception.hh" +#include +#include +#include +#include + +namespace Moses +{ +class Factor; + +class NonTermContextProperty : public PhraseProperty +{ +public: + + NonTermContextProperty(); + ~NonTermContextProperty(); + + virtual void ProcessValue(const std::string &value); + + virtual const std::string *GetValueString() const { + UTIL_THROW2("NonTermContextProperty: value string not available in this phrase property"); + return NULL; + }; + + float GetProb(size_t ntInd, + size_t contextInd, + const Factor *factor, + float smoothConstant) const; + +protected: + + class ProbStore { + typedef std::map Map; // map word -> prob + typedef std::vector Vec; // left outside, left inside, right inside, right outside + Vec m_vec; + float m_totalCount; + + float GetCount(size_t contextInd, + const Factor *factor, + float smoothConstant) const; + float GetTotalCount(size_t contextInd, float smoothConstant) const; + + public: + + ProbStore() + :m_vec(4) + ,m_totalCount(0) + {} + + float GetProb(size_t contextInd, + const Factor *factor, + float smoothConstant) const; + + float GetSize(size_t index) const + { return m_vec[index].size(); } + + void AddToMap(size_t index, const Factor *factor, float count); + + }; + + // by nt index + std::vector m_probStores; + + void AddToMap(size_t ntIndex, size_t index, const Factor *factor, float count); + +}; + +} // namespace Moses + diff --git a/moses/PP/PhraseProperty.cpp b/moses/PP/PhraseProperty.cpp new file mode 100644 index 
000000000..614b39c60 --- /dev/null +++ b/moses/PP/PhraseProperty.cpp @@ -0,0 +1,13 @@ +#include "PhraseProperty.h" + +namespace Moses +{ + +std::ostream& operator<<(std::ostream &out, const PhraseProperty &obj) +{ + out << "Base phrase property"; + return out; +} + +} + diff --git a/moses/PP/PhraseProperty.h b/moses/PP/PhraseProperty.h index b977787b2..b7437369b 100644 --- a/moses/PP/PhraseProperty.h +++ b/moses/PP/PhraseProperty.h @@ -10,16 +10,19 @@ namespace Moses */ class PhraseProperty { + friend std::ostream& operator<<(std::ostream &, const PhraseProperty &); + public: - PhraseProperty(const std::string &value) : m_value(value) {}; + PhraseProperty() : m_value(NULL) {}; + ~PhraseProperty() { if ( m_value != NULL ) delete m_value; }; - virtual void ProcessValue() {}; + virtual void ProcessValue(const std::string &value) { m_value = new std::string(value); }; - const std::string &GetValueString() { return m_value; }; + virtual const std::string *GetValueString() const { return m_value; }; protected: - const std::string m_value; + std::string *m_value; }; diff --git a/moses/PP/SourceLabelsPhraseProperty.cpp b/moses/PP/SourceLabelsPhraseProperty.cpp new file mode 100644 index 000000000..bca5c9a30 --- /dev/null +++ b/moses/PP/SourceLabelsPhraseProperty.cpp @@ -0,0 +1,124 @@ +#include "moses/PP/SourceLabelsPhraseProperty.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Moses +{ + +void SourceLabelsPhraseProperty::ProcessValue(const std::string &value) +{ + std::istringstream tokenizer(value); + + if (! (tokenizer >> m_nNTs)) { // first token: number of non-terminals (incl. left-hand side) + UTIL_THROW2("SourceLabelsPhraseProperty: Not able to read number of non-terminals. Flawed property?"); + } + assert( m_nNTs > 0 ); + + if (! (tokenizer >> m_totalCount)) { // second token: overall rule count + UTIL_THROW2("SourceLabelsPhraseProperty: Not able to read overall rule count. 
Flawed property?"); + } + assert( m_totalCount > 0.0 ); + + + + // read source-labelled rule items + + std::priority_queue ruleLabelledCountsPQ; + + while (tokenizer.peek() != EOF) { + try { + + SourceLabelsPhrasePropertyItem item; + size_t numberOfLHSsGivenRHS = std::numeric_limits::max(); + + if (m_nNTs == 1) { + + item.m_sourceLabelsRHSCount = m_totalCount; + + } else { // rule has right-hand side non-terminals, i.e. it's a hierarchical rule + + for (size_t i=0; i> sourceLabelRHS) ) { // RHS source non-terminal label + UTIL_THROW2("SourceLabelsPhraseProperty: Not able to read right-hand side label index. Flawed property?"); + } + item.m_sourceLabelsRHS.push_back(sourceLabelRHS); + } + + if (! (tokenizer >> item.m_sourceLabelsRHSCount)) { + UTIL_THROW2("SourceLabelsPhraseProperty: Not able to read right-hand side count. Flawed property?"); + } + + if (! (tokenizer >> numberOfLHSsGivenRHS)) { + UTIL_THROW2("SourceLabelsPhraseProperty: Not able to read number of left-hand sides. Flawed property?"); + } + } + + for (size_t i=0; i> sourceLabelLHS)) { // LHS source non-terminal label + UTIL_THROW2("SourceLabelsPhraseProperty: Not able to read left-hand side label index. Flawed property?"); + } + float ruleSourceLabelledCount; + if (! (tokenizer >> ruleSourceLabelledCount)) { + UTIL_THROW2("SourceLabelsPhraseProperty: Not able to read count. Flawed property?"); + } + item.m_sourceLabelsLHSList.push_back( std::make_pair(sourceLabelLHS,ruleSourceLabelledCount) ); + ruleLabelledCountsPQ.push(ruleSourceLabelledCount); + } + + m_sourceLabelItems.push_back(item); + + } catch (const std::exception &e) { + UTIL_THROW2("SourceLabelsPhraseProperty: Read error. 
Flawed property?"); + } + } + + // keep only top N label vectors + const size_t N=50; + + if (ruleLabelledCountsPQ.size() > N) { + + float topNRuleLabelledCount = std::numeric_limits::max(); + for (size_t i=0; !ruleLabelledCountsPQ.empty() && i::iterator itemIter=m_sourceLabelItems.begin(); + while (itemIter!=m_sourceLabelItems.end()) { + if (itemIter->m_sourceLabelsRHSCount < topNRuleLabelledCount) { + itemIter = m_sourceLabelItems.erase(itemIter); + } else { + std::list< std::pair >::iterator itemLHSIter=(itemIter->m_sourceLabelsLHSList).begin(); + while (itemLHSIter!=(itemIter->m_sourceLabelsLHSList).end()) { + if (itemLHSIter->second < topNRuleLabelledCount) { + itemLHSIter = (itemIter->m_sourceLabelsLHSList).erase(itemLHSIter); + } else { + if (nKept >= N) { + itemLHSIter = (itemIter->m_sourceLabelsLHSList).erase(itemLHSIter,(itemIter->m_sourceLabelsLHSList).end()); + } else { + ++nKept; + ++itemLHSIter; + } + } + } + if ((itemIter->m_sourceLabelsLHSList).empty()) { + itemIter = m_sourceLabelItems.erase(itemIter); + } else { + ++itemIter; + } + } + } + } +}; + +} // namespace Moses + diff --git a/moses/PP/SourceLabelsPhraseProperty.h b/moses/PP/SourceLabelsPhraseProperty.h new file mode 100644 index 000000000..39b43ad3e --- /dev/null +++ b/moses/PP/SourceLabelsPhraseProperty.h @@ -0,0 +1,77 @@ + +#pragma once + +#include "moses/PP/PhraseProperty.h" +#include "util/exception.hh" +#include +#include + +namespace Moses +{ + +// Note that we require label tokens (strings) in the corresponding property values of phrase table entries +// to be replaced beforehand by indices (size_t) of a label vocabulary. (TODO: change that?) 
+ +class SourceLabelsPhrasePropertyItem +{ +friend class SourceLabelsPhraseProperty; + +public: + SourceLabelsPhrasePropertyItem() {}; + + float GetSourceLabelsRHSCount() const + { + return m_sourceLabelsRHSCount; + }; + + const std::list &GetSourceLabelsRHS() const + { + return m_sourceLabelsRHS; + }; + + const std::list< std::pair > &GetSourceLabelsLHSList() const + { + return m_sourceLabelsLHSList; + }; + +private: + float m_sourceLabelsRHSCount; + std::list m_sourceLabelsRHS; // should be of size nNTs-1 (empty if initial rule, i.e. no right-hand side non-terminals) + std::list< std::pair > m_sourceLabelsLHSList; // list of left-hand sides for this right-hand side, with counts +}; + + +class SourceLabelsPhraseProperty : public PhraseProperty +{ +public: + SourceLabelsPhraseProperty() {}; + + virtual void ProcessValue(const std::string &value); + + size_t GetNumberOfNonTerminals() const { + return m_nNTs; + } + + float GetTotalCount() const { + return m_totalCount; + } + + const std::list &GetSourceLabelItems() const { + return m_sourceLabelItems; + }; + + virtual const std::string *GetValueString() const { + UTIL_THROW2("SourceLabelsPhraseProperty: value string not available in this phrase property"); + return NULL; + }; + +protected: + + size_t m_nNTs; + float m_totalCount; + + std::list m_sourceLabelItems; +}; + +} // namespace Moses + diff --git a/moses/PP/SpanLengthPhraseProperty.cpp b/moses/PP/SpanLengthPhraseProperty.cpp new file mode 100644 index 000000000..d45c7b919 --- /dev/null +++ b/moses/PP/SpanLengthPhraseProperty.cpp @@ -0,0 +1,127 @@ +#include "SpanLengthPhraseProperty.h" +#include "moses/Util.h" +#include "util/exception.hh" + +using namespace std; + +namespace Moses +{ +SpanLengthPhraseProperty::SpanLengthPhraseProperty() +{ +} + +void SpanLengthPhraseProperty::ProcessValue(const std::string &value) +{ + vector toks; + Tokenize(toks, value); + + set< vector > indices; + + for (size_t i = 0; i < toks.size(); ++i) { + const string &span = toks[i]; 
+ + // is it a ntIndex,sourceSpan,targetSpan or count ? + vector toks; + Tokenize(toks, span, ","); + UTIL_THROW_IF2(toks.size() != 1 && toks.size() != 3, "Incorrect format for SpanLength: " << span); + + if (toks.size() == 1) { + float count = Scan(toks[0]); + Populate(indices, count); + + indices.clear(); + } + else { + indices.insert(toks); + } + } + + // totals + CalcTotals(m_source); + CalcTotals(m_target); +} + +void SpanLengthPhraseProperty::Populate(const set< vector > &indices, float count) +{ + set< vector >::const_iterator iter; + for (iter = indices.begin(); iter != indices.end(); ++iter) { + const vector &toksStr = *iter; + vector toks = Scan(toksStr); + UTIL_THROW_IF2(toks.size() != 3, "Incorrect format for SpanLength. Size is " << toks.size()); + + Populate(toks, count); + } +} + +void SpanLengthPhraseProperty::Populate(const std::vector &toks, float count) +{ + size_t ntInd = toks[0]; + size_t sourceLength = toks[1]; + size_t targetLength = toks[2]; + if (ntInd >= m_source.size() ) { + m_source.resize(ntInd + 1); + m_target.resize(ntInd + 1); + } + + Map &sourceMap = m_source[ntInd].first; + Map &targetMap = m_target[ntInd].first; + Populate(sourceMap, sourceLength, count); + Populate(targetMap, targetLength, count); +} + +void SpanLengthPhraseProperty::Populate(Map &map, size_t span, float count) +{ + Map::iterator iter; + iter = map.find(span); + if (iter != map.end()) { + float &value = iter->second; + value += count; + } + else { + map[span] = count; + } +} + +void SpanLengthPhraseProperty::CalcTotals(Vec &vec) +{ + for (size_t i = 0; i < vec.size(); ++i) { + float total = 0; + + const Map &map = vec[i].first; + Map::const_iterator iter; + for (iter = map.begin(); iter != map.end(); ++iter) { + float count = iter->second; + total += count; + } + + vec[i].second = total; + } +} + +float SpanLengthPhraseProperty::GetProb(size_t ntInd, size_t sourceWidth, float smoothing) const +{ + float count; + + const std::pair &data = m_source[ntInd]; + const 
Map &map = data.first; + + if (map.size() == 0) { + // should this ever be reached? there shouldn't be any span length proprty so FF shouldn't call this + return 1.0f; + } + + Map::const_iterator iter = map.find(sourceWidth); + if (iter == map.end()) { + count = 0; + } + else { + count = iter->second; + } + count += smoothing; + + float total = data.second + smoothing * (float) map.size(); + float ret = count / total; + return ret; +} + +} diff --git a/moses/PP/SpanLengthPhraseProperty.h b/moses/PP/SpanLengthPhraseProperty.h new file mode 100644 index 000000000..982c3ca0d --- /dev/null +++ b/moses/PP/SpanLengthPhraseProperty.h @@ -0,0 +1,35 @@ + +#pragma once + +#include +#include +#include +#include +#include "moses/PP/PhraseProperty.h" + +namespace Moses +{ + +class SpanLengthPhraseProperty : public PhraseProperty +{ +public: + SpanLengthPhraseProperty(); + + void ProcessValue(const std::string &value); + + float GetProb(size_t ntInd, size_t sourceWidth, float smoothing) const; +protected: + // fractional counts + typedef std::map Map; + typedef std::vector > Vec; + Vec m_source, m_target; + + void Populate(const std::set< std::vector > &indices, float count); + void Populate(const std::vector &toks, float count); + void Populate(Map &map, size_t span, float count); + + void CalcTotals(Vec &vec); +}; + +} // namespace Moses + diff --git a/moses/PP/TreeStructurePhraseProperty.h b/moses/PP/TreeStructurePhraseProperty.h index f9acc38dd..45124973f 100644 --- a/moses/PP/TreeStructurePhraseProperty.h +++ b/moses/PP/TreeStructurePhraseProperty.h @@ -10,7 +10,7 @@ namespace Moses class TreeStructurePhraseProperty : public PhraseProperty { public: - TreeStructurePhraseProperty(const std::string &value) : PhraseProperty(value) {}; + TreeStructurePhraseProperty() {}; }; diff --git a/moses/Parameter.cpp b/moses/Parameter.cpp index b0ac9e811..10ac56627 100644 --- a/moses/Parameter.cpp +++ b/moses/Parameter.cpp @@ -202,8 +202,7 @@ Parameter::Parameter() 
AddParam("placeholder-factor", "Which source factor to use to store the original text for placeholders. The factor must not be used by a translation or gen model"); AddParam("no-cache", "Disable all phrase-table caching. Default = false (ie. enable caching)"); - - AddParam("adjacent-only", "Only allow hypotheses which are adjacent to current derivation. ITG without block moves"); + AddParam("default-non-term-for-empty-range-only", "Don't add [X] to all ranges, just ranges where there isn't a source non-term. Default = false (ie. add [X] everywhere)"); } diff --git a/moses/Phrase.h b/moses/Phrase.h index 4a5c4828a..f6eb661de 100644 --- a/moses/Phrase.h +++ b/moses/Phrase.h @@ -47,8 +47,8 @@ class WordsRange; class Phrase { friend std::ostream& operator<<(std::ostream&, const Phrase&); -private: - + // private: +protected: std::vector m_words; public: diff --git a/moses/ScoreComponentCollection.cpp b/moses/ScoreComponentCollection.cpp index e252d1a7a..52ec00dd4 100644 --- a/moses/ScoreComponentCollection.cpp +++ b/moses/ScoreComponentCollection.cpp @@ -214,7 +214,7 @@ void ScoreComponentCollection::Save(const string& filename) const void ScoreComponentCollection:: -Assign(const FeatureFunction* sp, const string line) +Assign(const FeatureFunction* sp, const string &line) { istringstream istr(line); while(istr) { diff --git a/moses/ScoreComponentCollection.h b/moses/ScoreComponentCollection.h index 68287296f..3cddbca67 100644 --- a/moses/ScoreComponentCollection.h +++ b/moses/ScoreComponentCollection.h @@ -1,3 +1,4 @@ +// -*- c++ -*- // $Id$ /*********************************************************************** @@ -93,10 +94,13 @@ class ScoreComponentCollection private: FVector m_scores; +public: typedef std::pair IndexPair; +private: typedef std::map ScoreIndexMap; static ScoreIndexMap s_scoreIndexes; static size_t s_denseVectorSize; +public: static IndexPair GetIndexes(const FeatureFunction* sp) { ScoreIndexMap::const_iterator indexIter = s_scoreIndexes.find(sp); 
if (indexIter == s_scoreIndexes.end()) { @@ -287,7 +291,7 @@ public: //Read sparse features from string - void Assign(const FeatureFunction* sp, const std::string line); + void Assign(const FeatureFunction* sp, const std::string &line); // shortcut: setting the value directly using the feature name void Assign(const std::string name, float score) { diff --git a/moses/ScoreComponentCollectionTest.cpp b/moses/ScoreComponentCollectionTest.cpp index de542d1f6..a238d66b8 100644 --- a/moses/ScoreComponentCollectionTest.cpp +++ b/moses/ScoreComponentCollectionTest.cpp @@ -34,16 +34,16 @@ class MockStatelessFeatureFunction : public StatelessFeatureFunction public: MockStatelessFeatureFunction(size_t n, const string &line) : StatelessFeatureFunction(n, line) {} - void Evaluate(const Hypothesis&, ScoreComponentCollection*) const {} - void EvaluateChart(const ChartHypothesis&, ScoreComponentCollection*) const {} - void Evaluate(const InputType &input + void EvaluateWhenApplied(const Hypothesis&, ScoreComponentCollection*) const {} + void EvaluateWhenApplied(const ChartHypothesis&, ScoreComponentCollection*) const {} + void EvaluateWithSourceContext(const InputType &input , const InputPath &inputPath , const TargetPhrase &targetPhrase , const StackVec *stackVec , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection *estimatedFutureScore) const {} - void Evaluate(const Phrase &source + void EvaluateInIsolation(const Phrase &source , const TargetPhrase &targetPhrase , ScoreComponentCollection &scoreBreakdown , ScoreComponentCollection &estimatedFutureScore) const diff --git a/moses/SearchCubePruning.cpp b/moses/SearchCubePruning.cpp index 49ca22645..b8382eadd 100644 --- a/moses/SearchCubePruning.cpp +++ b/moses/SearchCubePruning.cpp @@ -86,7 +86,7 @@ void SearchCubePruning::ProcessSentence() // go through each stack size_t stackNo = 1; std::vector < HypothesisStack* >::iterator iterStack; - for (iterStack = ++m_hypoStackColl.begin() ; iterStack != 
m_hypoStackColl.end() ; ++iterStack) { + for (iterStack = m_hypoStackColl.begin() + 1 ; iterStack != m_hypoStackColl.end() ; ++iterStack) { // check if decoding ran out of time double _elapsed_time = GetUserTime(); if (_elapsed_time > staticData.GetTimeoutThreshold()) { @@ -250,11 +250,6 @@ bool SearchCubePruning::CheckDistortion(const WordsBitmap &hypoBitmap, const Wor return true; } - if (StaticData::Instance().AdjacentOnly() && - !hypoBitmap.IsAdjacent(range.GetStartPos(), range.GetEndPos())) { - return false; - } - bool leftMostEdge = (hypoFirstGapPos == startPos); // any length extension is okay if starting at left-most edge if (leftMostEdge) { diff --git a/moses/SearchNormal.cpp b/moses/SearchNormal.cpp index 8ac0eca13..d40324c15 100644 --- a/moses/SearchNormal.cpp +++ b/moses/SearchNormal.cpp @@ -93,7 +93,9 @@ void SearchNormal::ProcessSentence() // this stack is fully expanded; actual_hypoStack = &sourceHypoColl; + } + //OutputHypoStack(); } @@ -253,11 +255,6 @@ void SearchNormal::ExpandAllHypotheses(const Hypothesis &hypothesis, size_t star expectedScore += m_transOptColl.GetFutureScore().CalcFutureScore( hypothesis.GetWordsBitmap(), startPos, endPos ); } - if (StaticData::Instance().AdjacentOnly() && - !hypothesis.GetWordsBitmap().IsAdjacent(startPos, endPos)) { - return; - } - // loop through all translation options const TranslationOptionList &transOptList = m_transOptColl.GetTranslationOptionList(WordsRange(startPos, endPos)); TranslationOptionList::const_iterator iter; @@ -386,4 +383,15 @@ void SearchNormal::OutputHypoStackSize() TRACE_ERR( endl); } +void SearchNormal::OutputHypoStack() +{ + // all stacks + int i = 0; + vector < HypothesisStack* >::iterator iterStack; + for (iterStack = m_hypoStackColl.begin() ; iterStack != m_hypoStackColl.end() ; ++iterStack) { + HypothesisStackNormal &hypoColl = *static_cast(*iterStack); + TRACE_ERR( "Stack " << i++ << ": " << endl << hypoColl << endl); + } +} + } diff --git a/moses/SearchNormal.h 
b/moses/SearchNormal.h index 49a0bae9d..d76e102c2 100644 --- a/moses/SearchNormal.h +++ b/moses/SearchNormal.h @@ -39,7 +39,7 @@ public: void ProcessSentence(); void OutputHypoStackSize(); - void OutputHypoStack(int stack); + void OutputHypoStack(); virtual const std::vector < HypothesisStack* >& GetHypothesisStacks() const; virtual const Hypothesis *GetBestHypothesis() const; diff --git a/moses/StaticData.cpp b/moses/StaticData.cpp index 0340778ed..badb189d4 100644 --- a/moses/StaticData.cpp +++ b/moses/StaticData.cpp @@ -388,8 +388,6 @@ bool StaticData::LoadData(Parameter *parameter) SetBooleanParameter( &m_lmEnableOOVFeature, "lmodel-oov-feature", false); - SetBooleanParameter( &m_adjacentOnly, "adjacent-only", false); - // minimum Bayes risk decoding SetBooleanParameter( &m_mbr, "minimum-bayes-risk", false ); m_mbrSize = (m_parameter->GetParam("mbr-size").size() > 0) ? @@ -429,6 +427,9 @@ bool StaticData::LoadData(Parameter *parameter) } if (m_useConsensusDecoding) m_mbr=true; + SetBooleanParameter( &m_defaultNonTermOnlyForEmptyRange, "default-non-term-for-empty-range-only", false ); + + // Compact phrase table and reordering model SetBooleanParameter( &m_minphrMemory, "minphr-memory", false ); SetBooleanParameter( &m_minlexrMemory, "minlexr-memory", false ); @@ -494,7 +495,8 @@ bool StaticData::LoadData(Parameter *parameter) } m_xmlBrackets.first= brackets[0]; m_xmlBrackets.second=brackets[1]; - cerr << "XML tags opening and closing brackets for XML input are: " << m_xmlBrackets.first << " and " << m_xmlBrackets.second << endl; + VERBOSE(1,"XML tags opening and closing brackets for XML input are: " + << m_xmlBrackets.first << " and " << m_xmlBrackets.second << endl); } if (m_parameter->GetParam("placeholder-factor").size() > 0) { @@ -511,7 +513,7 @@ bool StaticData::LoadData(Parameter *parameter) const vector &features = m_parameter->GetParam("feature"); for (size_t i = 0; i < features.size(); ++i) { const string &line = Trim(features[i]); - cerr << "line=" << 
line << endl; + VERBOSE(1,"line=" << line << endl); if (line.empty()) continue; @@ -535,7 +537,9 @@ bool StaticData::LoadData(Parameter *parameter) NoCache(); OverrideFeatures(); - LoadFeatureFunctions(); + if (!m_parameter->isParamSpecified("show-weights")) { + LoadFeatureFunctions(); + } if (!LoadDecodeGraphs()) return false; @@ -640,7 +644,8 @@ void StaticData::LoadNonTerminals() "Incorrect unknown LHS format: " << line); UnknownLHSEntry entry(tokens[0], Scan(tokens[1])); m_unknownLHS.push_back(entry); - const Factor *targetFactor = factorCollection.AddFactor(Output, 0, tokens[0], true); + // const Factor *targetFactor = + factorCollection.AddFactor(Output, 0, tokens[0], true); } } @@ -734,7 +739,7 @@ bool StaticData::LoadDecodeGraphs() DecodeGraph *decodeGraph; if (IsChart()) { size_t maxChartSpan = (decodeGraphInd < maxChartSpans.size()) ? maxChartSpans[decodeGraphInd] : DEFAULT_MAX_CHART_SPAN; - cerr << "max-chart-span: " << maxChartSpans[decodeGraphInd] << endl; + VERBOSE(1,"max-chart-span: " << maxChartSpans[decodeGraphInd] << endl); decodeGraph = new DecodeGraph(m_decodeGraphs.size(), maxChartSpan); } else { decodeGraph = new DecodeGraph(m_decodeGraphs.size()); @@ -866,7 +871,7 @@ void StaticData::SetExecPath(const std::string &path) if (pos != string::npos) { m_binPath = path.substr(0, pos); } - cerr << m_binPath << endl; + VERBOSE(1,m_binPath << endl); } const string &StaticData::GetBinDirectory() const @@ -920,7 +925,8 @@ void StaticData::LoadFeatureFunctions() FeatureFunction *ff = *iter; bool doLoad = true; - if (PhraseDictionary *ffCast = dynamic_cast(ff)) { + // if (PhraseDictionary *ffCast = dynamic_cast(ff)) { + if (dynamic_cast(ff)) { doLoad = false; } @@ -964,7 +970,7 @@ bool StaticData::CheckWeights() const set::iterator iter; for (iter = weightNames.begin(); iter != weightNames.end(); ) { string fname = (*iter).substr(0, (*iter).find("_")); - cerr << fname << "\n"; + VERBOSE(1,fname << "\n"); if (featureNames.find(fname) != featureNames.end()) 
{ weightNames.erase(iter++); } @@ -1039,7 +1045,7 @@ bool StaticData::LoadAlternateWeightSettings() vector tokens = Tokenize(weightSpecification[i]); vector args = Tokenize(tokens[0], "="); currentId = args[1]; - cerr << "alternate weight setting " << currentId << endl; + VERBOSE(1,"alternate weight setting " << currentId << endl); UTIL_THROW_IF2(m_weightSetting.find(currentId) != m_weightSetting.end(), "Duplicate alternate weight id: " << currentId); m_weightSetting[ currentId ] = new ScoreComponentCollection; diff --git a/moses/StaticData.h b/moses/StaticData.h index 882ac912e..68e1ee60c 100644 --- a/moses/StaticData.h +++ b/moses/StaticData.h @@ -198,7 +198,7 @@ protected: FactorType m_placeHolderFactor; bool m_useLegacyPT; - bool m_adjacentOnly; + bool m_defaultNonTermOnlyForEmptyRange; FeatureRegistry m_registry; PhrasePropertyFactory m_phrasePropertyFactory; @@ -756,13 +756,8 @@ public: } - bool AdjacentOnly() const - { return m_adjacentOnly; } - - void ResetWeights(const std::string &denseWeights, const std::string &sparseFile); - // need global access for output of tree structure const StatefulFeatureFunction* GetTreeStructure() const { return m_treeStructure; @@ -772,6 +767,9 @@ public: m_treeStructure = treeStructure; } + bool GetDefaultNonTermOnlyForEmptyRange() const + { return m_defaultNonTermOnlyForEmptyRange; } + }; } diff --git a/moses/SyntacticLanguageModel.h b/moses/SyntacticLanguageModel.h index 6e88d85c1..76882a4d1 100644 --- a/moses/SyntacticLanguageModel.h +++ b/moses/SyntacticLanguageModel.h @@ -30,7 +30,7 @@ public: const FFState* prev_state, ScoreComponentCollection* accumulator) const; - FFState* EvaluateChart(const ChartHypothesis& cur_hypo, + FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureID, ScoreComponentCollection* accumulator) const { throw std::runtime_error("Syntactic LM can only be used with phrase-based decoder."); diff --git a/moses/TargetPhrase.cpp b/moses/TargetPhrase.cpp index e2a325994..aef4f0fee 
100644 --- a/moses/TargetPhrase.cpp +++ b/moses/TargetPhrase.cpp @@ -129,7 +129,7 @@ void TargetPhrase::Evaluate(const Phrase &source, const std::vector &value) const +const PhraseProperty *TargetPhrase::GetProperty(const std::string &key) const { std::map >::const_iterator iter; iter = m_properties.find(key); if (iter != m_properties.end()) { - value = iter->second; - return true; + const boost::shared_ptr &pp = iter->second; + return pp.get(); } - return false; + return NULL; } void TargetPhrase::SetRuleSource(const Phrase &ruleSource) const @@ -288,12 +288,25 @@ std::ostream& operator<<(std::ostream& os, const TargetPhrase& tp) os << tp.GetAlignNonTerm() << flush; os << ": c=" << tp.m_fullScore << flush; os << " " << tp.m_scoreBreakdown << flush; - + const Phrase *sourcePhrase = tp.GetRuleSource(); if (sourcePhrase) { os << " sourcePhrase=" << *sourcePhrase << flush; } + if (tp.m_properties.size()) { + os << " properties: " << flush; + + TargetPhrase::Properties::const_iterator iter; + for (iter = tp.m_properties.begin(); iter != tp.m_properties.end(); ++iter) { + const string &key = iter->first; + const PhraseProperty *prop = iter->second.get(); + assert(prop); + + os << key << "=" << *prop << " "; + } + } + return os; } diff --git a/moses/TargetPhrase.h b/moses/TargetPhrase.h index 1f2fa96dd..1e9e51c79 100644 --- a/moses/TargetPhrase.h +++ b/moses/TargetPhrase.h @@ -57,7 +57,8 @@ private: const Word *m_lhsTarget; mutable Phrase *m_ruleSource; // to be set by the feature function that needs it. 
- std::map > m_properties; + typedef std::map > Properties; + Properties m_properties; public: TargetPhrase(); @@ -137,7 +138,7 @@ public: void SetProperties(const StringPiece &str); void SetProperty(const std::string &key, const std::string &value); - bool GetProperty(const std::string &key, boost::shared_ptr &value) const; + const PhraseProperty *GetProperty(const std::string &key) const; void Merge(const TargetPhrase ©, const std::vector& factorVec); diff --git a/moses/TargetPhraseCollection.h b/moses/TargetPhraseCollection.h index 47eee0458..0c6a7a74c 100644 --- a/moses/TargetPhraseCollection.h +++ b/moses/TargetPhraseCollection.h @@ -44,6 +44,12 @@ public: typedef CollType::iterator iterator; typedef CollType::const_iterator const_iterator; + TargetPhrase const* + operator[](size_t const i) const + { + return m_collection.at(i); + } + iterator begin() { return m_collection.begin(); } diff --git a/moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerOnDisk.cpp b/moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerOnDisk.cpp index a83f85e60..163be8937 100644 --- a/moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerOnDisk.cpp +++ b/moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerOnDisk.cpp @@ -82,6 +82,8 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection( ChartParserCallback &outColl) { const StaticData &staticData = StaticData::Instance(); + const Word &defaultSourceNonTerm = staticData.GetInputDefaultNonTerminal(); + size_t relEndPos = range.GetEndPos() - range.GetStartPos(); size_t absEndPos = range.GetEndPos(); @@ -137,8 +139,6 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection( stackInd = relEndPos + 1; } - // size_t nonTermNumWordsCovered = endPos - startPos + 1; - // get target nonterminals in this span from chart const ChartCellLabelSet &chartNonTermSet = GetTargetLabelSet(startPos, endPos); @@ -174,11 +174,21 @@ void ChartRuleLookupManagerOnDisk::GetChartRuleCollection( } const ChartCellLabel &cellLabel 
= **iterChartNonTerm; - //cerr << sourceLHS << " " << defaultSourceNonTerm << " " << chartNonTerm << " " << defaultTargetNonTerm << endl; + bool doSearch = true; + if (m_dictionary.m_maxSpanDefault != NOT_FOUND) { + // for Hieu's source syntax + const Word &targetLHS = cellLabel.GetLabel(); - //bool isSyntaxNonTerm = (sourceLHS != defaultSourceNonTerm) || (chartNonTerm != defaultTargetNonTerm); - bool doSearch = true; //isSyntaxNonTerm ? nonTermNumWordsCovered <= maxSyntaxSpan : - // nonTermNumWordsCovered <= maxDefaultSpan; + bool isSourceSyntaxNonTerm = sourceLHS != defaultSourceNonTerm; + size_t nonTermNumWordsCovered = endPos - startPos + 1; + + doSearch = isSourceSyntaxNonTerm ? + nonTermNumWordsCovered <= m_dictionary.m_maxSpanLabelled : + nonTermNumWordsCovered <= m_dictionary.m_maxSpanDefault; + + //cerr << "sourceLHS=" << sourceLHS << " targetLHS=" << targetLHS + // << "doSearch=" << doSearch << endl; + } if (doSearch) { diff --git a/moses/TranslationModel/CYKPlusParser/CompletedRuleCollection.cpp b/moses/TranslationModel/CYKPlusParser/CompletedRuleCollection.cpp index d060ec273..294b93fe2 100644 --- a/moses/TranslationModel/CYKPlusParser/CompletedRuleCollection.cpp +++ b/moses/TranslationModel/CYKPlusParser/CompletedRuleCollection.cpp @@ -29,7 +29,7 @@ namespace Moses CompletedRuleCollection::CompletedRuleCollection() : m_ruleLimit(StaticData::Instance().GetRuleLimit()) { - m_scoreThreshold = numeric_limits::infinity(); + m_scoreThreshold = numeric_limits::infinity(); } // copies some functionality (pruning) from ChartTranslationOptionList::Add @@ -37,33 +37,33 @@ void CompletedRuleCollection::Add(const TargetPhraseCollection &tpc, const StackVec &stackVec, const ChartParserCallback &outColl) { - if (tpc.IsEmpty()) { - return; - } + if (tpc.IsEmpty()) { + return; + } - const TargetPhrase &targetPhrase = **(tpc.begin()); - float score = targetPhrase.GetFutureScore(); - for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) { - float 
stackScore = (*p)->GetBestScore(&outColl); - score += stackScore; - } + const TargetPhrase &targetPhrase = **(tpc.begin()); + float score = targetPhrase.GetFutureScore(); + for (StackVec::const_iterator p = stackVec.begin(); p != stackVec.end(); ++p) { + float stackScore = (*p)->GetBestScore(&outColl); + score += stackScore; + } - // If the rule limit has already been reached then don't add the option - // unless it is better than at least one existing option. - if (m_collection.size() > m_ruleLimit && score < m_scoreThreshold) { - return; - } + // If the rule limit has already been reached then don't add the option + // unless it is better than at least one existing option. + if (m_ruleLimit && m_collection.size() > m_ruleLimit && score < m_scoreThreshold) { + return; + } - CompletedRule *completedRule = new CompletedRule(tpc, stackVec, score); - m_collection.push_back(completedRule); + CompletedRule *completedRule = new CompletedRule(tpc, stackVec, score); + m_collection.push_back(completedRule); // If the rule limit hasn't been exceeded then update the threshold. - if (m_collection.size() <= m_ruleLimit) { + if (!m_ruleLimit || m_collection.size() <= m_ruleLimit) { m_scoreThreshold = (score < m_scoreThreshold) ? 
score : m_scoreThreshold; } // Prune if bursting - if (m_collection.size() == m_ruleLimit * 2) { + if (m_ruleLimit && m_collection.size() == m_ruleLimit * 2) { NTH_ELEMENT4(m_collection.begin(), m_collection.begin() + m_ruleLimit - 1, m_collection.end(), @@ -77,4 +77,4 @@ void CompletedRuleCollection::Add(const TargetPhraseCollection &tpc, } } -} \ No newline at end of file +} diff --git a/moses/TranslationModel/PhraseDictionaryMultiModel.cpp b/moses/TranslationModel/PhraseDictionaryMultiModel.cpp index 9f3996505..a1824b475 100644 --- a/moses/TranslationModel/PhraseDictionaryMultiModel.cpp +++ b/moses/TranslationModel/PhraseDictionaryMultiModel.cpp @@ -323,9 +323,6 @@ void PhraseDictionaryMultiModel::SetTemporaryMultiModelWeightsVector(std::vector vector PhraseDictionaryMultiModel::MinimizePerplexity(vector > &phrase_pair_vector) { - const StaticData &staticData = StaticData::Instance(); - const string& factorDelimiter = staticData.GetFactorDelimiter(); - map, size_t> phrase_pair_map; for ( vector >::const_iterator iter = phrase_pair_vector.begin(); iter != phrase_pair_vector.end(); ++iter ) { @@ -344,7 +341,7 @@ vector PhraseDictionaryMultiModel::MinimizePerplexity(vector* allStats = new(map); Phrase sourcePhrase(0); - sourcePhrase.CreateFromString(Input, m_input, source_string, factorDelimiter, NULL); + sourcePhrase.CreateFromString(Input, m_input, source_string, NULL); CollectSufficientStatistics(sourcePhrase, allStats); //optimization potential: only call this once per source phrase diff --git a/moses/TranslationModel/PhraseDictionaryMultiModelCounts.cpp b/moses/TranslationModel/PhraseDictionaryMultiModelCounts.cpp index 04bb321d0..83aa4a718 100644 --- a/moses/TranslationModel/PhraseDictionaryMultiModelCounts.cpp +++ b/moses/TranslationModel/PhraseDictionaryMultiModelCounts.cpp @@ -17,12 +17,8 @@ License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 
***********************************************************************/ #include "util/exception.hh" - #include "moses/TranslationModel/PhraseDictionaryMultiModelCounts.h" -#define LINE_MAX_LENGTH 100000 -#include "phrase-extract/SafeGetline.h" // for SAFE_GETLINE() - using namespace std; template @@ -461,16 +457,14 @@ void PhraseDictionaryMultiModelCounts::LoadLexicalTable( string &fileName, lexic } istream *inFileP = &inFile; - char line[LINE_MAX_LENGTH]; - int i=0; - while(true) { + string line; + + while(getline(*inFileP, line)) { i++; if (i%100000 == 0) cerr << "." << flush; - SAFE_GETLINE((*inFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - if (inFileP->eof()) break; - vector token = tokenize( line ); + vector token = tokenize( line.c_str() ); if (token.size() != 4) { cerr << "line " << i << " in " << fileName << " has wrong number of tokens, skipping:\n" @@ -495,9 +489,6 @@ void PhraseDictionaryMultiModelCounts::LoadLexicalTable( string &fileName, lexic vector PhraseDictionaryMultiModelCounts::MinimizePerplexity(vector > &phrase_pair_vector) { - const StaticData &staticData = StaticData::Instance(); - const string& factorDelimiter = staticData.GetFactorDelimiter(); - map, size_t> phrase_pair_map; for ( vector >::const_iterator iter = phrase_pair_vector.begin(); iter != phrase_pair_vector.end(); ++iter ) { @@ -516,7 +507,7 @@ vector PhraseDictionaryMultiModelCounts::MinimizePerplexity(vector* allStats = new(map); Phrase sourcePhrase(0); - sourcePhrase.CreateFromString(Input, m_input, source_string, factorDelimiter, NULL); + sourcePhrase.CreateFromString(Input, m_input, source_string, NULL); CollectSufficientStatistics(sourcePhrase, fs, allStats); //optimization potential: only call this once per source phrase diff --git a/moses/TranslationModel/PhraseDictionaryTree.cpp b/moses/TranslationModel/PhraseDictionaryTree.cpp index 68dd5a59f..c8b7cb5d2 100644 --- a/moses/TranslationModel/PhraseDictionaryTree.cpp +++ b/moses/TranslationModel/PhraseDictionaryTree.cpp 
@@ -3,6 +3,7 @@ #include "moses/FeatureVector.h" #include "moses/TranslationModel/PhraseDictionaryTree.h" #include "util/exception.hh" +#include "moses/StaticData.h" #include #include @@ -233,7 +234,8 @@ public: typedef PhraseDictionaryTree::PrefixPtr PPtr; void GetTargetCandidates(PPtr p,TgtCands& tgtCands) { - UTIL_THROW_IF2(p == NULL, "Error"); + UTIL_THROW_IF2(p == 0L, "Error"); + // UTIL_THROW_IF2(p == NULL, "Error"); if(p.imp->isRoot()) return; OFF_T tCandOffset=p.imp->ptr()->getData(p.imp->idx); @@ -278,7 +280,8 @@ public: } PPtr Extend(PPtr p,const std::string& w) { - UTIL_THROW_IF2(p == NULL, "Error"); + UTIL_THROW_IF2(p == 0L, "Error"); + // UTIL_THROW_IF2(p == NULL, "Error"); if(w.empty() || w==EPSILON) return p; @@ -349,8 +352,8 @@ int PDTimp::Read(const std::string& fn) sv.Read(ifsv); tv.Read(iftv); - TRACE_ERR("binary phrasefile loaded, default OFF_T: "<Read(fn); } diff --git a/moses/TranslationModel/ProbingPT/Jamfile b/moses/TranslationModel/ProbingPT/Jamfile new file mode 100644 index 000000000..d30ae3486 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/Jamfile @@ -0,0 +1,13 @@ +local current = "" ; +local includes = ; +if [ option.get "with-probing-pt" : : "yes" ] +{ + fakelib ProbingPT : [ glob *.cpp ] ../..//headers : $(includes) $(PT-LOG) : : $(includes) ; +} +else { + fakelib ProbingPT ; +} + +path-constant PT-LOG : bin/pt.log ; +update-if-changed $(PT-LOG) $(current) ; + diff --git a/moses/TranslationModel/ProbingPT/ProbingPT.cpp b/moses/TranslationModel/ProbingPT/ProbingPT.cpp new file mode 100644 index 000000000..9859520c1 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/ProbingPT.cpp @@ -0,0 +1,231 @@ +// vim:tabstop=2 +#include "ProbingPT.h" +#include "moses/StaticData.h" +#include "moses/FactorCollection.h" +#include "moses/TranslationModel/CYKPlusParser/ChartRuleLookupManagerSkeleton.h" +#include "quering.hh" + +using namespace std; + +namespace Moses +{ +ProbingPT::ProbingPT(const std::string &line) +: PhraseDictionary(line) 
+,m_engine(NULL) +{ + ReadParameters(); + + assert(m_input.size() == 1); + assert(m_output.size() == 1); +} + +ProbingPT::~ProbingPT() +{ + delete m_engine; +} + +void ProbingPT::Load() +{ + SetFeaturesToApply(); + + m_engine = new QueryEngine(m_filePath.c_str()); + + m_unkId = 456456546456; + + // source vocab + const std::map &sourceVocab = m_engine->getSourceVocab(); + std::map::const_iterator iterSource; + for (iterSource = sourceVocab.begin(); iterSource != sourceVocab.end(); ++iterSource) { + const string &wordStr = iterSource->second; + const Factor *factor = FactorCollection::Instance().AddFactor(wordStr); + + uint64_t probingId = iterSource->first; + + SourceVocabMap::value_type entry(factor, probingId); + m_sourceVocabMap.insert(entry); + + } + + // target vocab + const std::map &probingVocab = m_engine->getVocab(); + std::map::const_iterator iter; + for (iter = probingVocab.begin(); iter != probingVocab.end(); ++iter) { + const string &wordStr = iter->second; + const Factor *factor = FactorCollection::Instance().AddFactor(wordStr); + + unsigned int probingId = iter->first; + + TargetVocabMap::value_type entry(factor, probingId); + m_vocabMap.insert(entry); + + } +} + +void ProbingPT::InitializeForInput(InputType const& source) +{ + ReduceCache(); +} + +void ProbingPT::GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const +{ + CacheColl &cache = GetCache(); + + InputPathList::const_iterator iter; + for (iter = inputPathQueue.begin(); iter != inputPathQueue.end(); ++iter) { + InputPath &inputPath = **iter; + const Phrase &sourcePhrase = inputPath.GetPhrase(); + + if (sourcePhrase.GetSize() > StaticData::Instance().GetMaxPhraseLength()) { + continue; + } + + TargetPhraseCollection *tpColl = CreateTargetPhrase(sourcePhrase); + + // add target phrase to phrase-table cache + size_t hash = hash_value(sourcePhrase); + std::pair value(tpColl, clock()); + cache[hash] = value; + + inputPath.SetTargetPhrases(*this, tpColl, NULL); + } +} + 
+std::vector ProbingPT::ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok) const +{ + size_t size = sourcePhrase.GetSize(); + std::vector ret(size); + for (size_t i = 0; i < size; ++i) { + const Factor *factor = sourcePhrase.GetFactor(i, m_input[0]); + uint64_t probingId = GetSourceProbingId(factor); + if (probingId == m_unkId) { + ok = false; + return ret; + } + else { + ret[i] = probingId; + } + } + + ok = true; + return ret; +} + +TargetPhraseCollection *ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase) const +{ + // create a target phrase from the 1st word of the source, prefix with 'ProbingPT:' + assert(sourcePhrase.GetSize()); + + bool ok; + vector probingSource = ConvertToProbingSourcePhrase(sourcePhrase, ok); + if (!ok) { + // source phrase contains a word unknown in the pt. + // We know immediately there's no translation for it + return NULL; + } + + std::pair > query_result; + + TargetPhraseCollection *tpColl = NULL; + + //Actual lookup + query_result = m_engine->query(probingSource); + + if (query_result.first) { + //m_engine->printTargetInfo(query_result.second); + tpColl = new TargetPhraseCollection(); + + const std::vector &probingTargetPhrases = query_result.second; + for (size_t i = 0; i < probingTargetPhrases.size(); ++i) { + const target_text &probingTargetPhrase = probingTargetPhrases[i]; + TargetPhrase *tp = CreateTargetPhrase(sourcePhrase, probingTargetPhrase); + + tpColl->Add(tp); + } + + tpColl->Prune(true, m_tableLimit); + } + + return tpColl; +} + +TargetPhrase *ProbingPT::CreateTargetPhrase(const Phrase &sourcePhrase, const target_text &probingTargetPhrase) const +{ + const std::vector &probingPhrase = probingTargetPhrase.target_phrase; + size_t size = probingPhrase.size(); + + TargetPhrase *tp = new TargetPhrase(); + + // words + for (size_t i = 0; i < size; ++i) { + uint64_t probingId = probingPhrase[i]; + const Factor *factor = GetTargetFactor(probingId); + assert(factor); + + Word &word = tp->AddWord(); + 
word.SetFactor(m_output[0], factor); + } + + // score for this phrase table + vector scores = probingTargetPhrase.prob; + std::transform(scores.begin(), scores.end(), scores.begin(),TransformScore); + tp->GetScoreBreakdown().PlusEquals(this, scores); + + // alignment + /* + const std::vector &alignments = probingTargetPhrase.word_all1; + + AlignmentInfo &aligns = tp->GetAlignTerm(); + for (size_t i = 0; i < alignS.size(); i += 2 ) { + aligns.Add((size_t) alignments[i], (size_t) alignments[i+1]); + } + */ + + // score of all other ff when this rule is being loaded + tp->Evaluate(sourcePhrase, GetFeaturesToApply()); + return tp; +} + +const Factor *ProbingPT::GetTargetFactor(uint64_t probingId) const +{ + TargetVocabMap::right_map::const_iterator iter; + iter = m_vocabMap.right.find(probingId); + if (iter != m_vocabMap.right.end()) { + return iter->second; + } + else { + // not in mapping. Must be UNK + return NULL; + } +} + +uint64_t ProbingPT::GetSourceProbingId(const Factor *factor) const +{ + SourceVocabMap::left_map::const_iterator iter; + iter = m_sourceVocabMap.left.find(factor); + if (iter != m_sourceVocabMap.left.end()) { + return iter->second; + } + else { + // not in mapping. 
Must be UNK + return m_unkId; + } +} + +ChartRuleLookupManager *ProbingPT::CreateRuleLookupManager( + const ChartParser &, + const ChartCellCollectionBase &, + std::size_t) +{ + abort(); + return NULL; +} + +TO_STRING_BODY(ProbingPT); + +// friend +ostream& operator<<(ostream& out, const ProbingPT& phraseDict) +{ + return out; +} + +} diff --git a/moses/TranslationModel/ProbingPT/ProbingPT.h b/moses/TranslationModel/ProbingPT/ProbingPT.h new file mode 100644 index 000000000..b879760cb --- /dev/null +++ b/moses/TranslationModel/ProbingPT/ProbingPT.h @@ -0,0 +1,59 @@ + +#pragma once + +#include +#include "../PhraseDictionary.h" + +class QueryEngine; +class target_text; + +namespace Moses +{ +class ChartParser; +class ChartCellCollectionBase; +class ChartRuleLookupManager; + +class ProbingPT : public PhraseDictionary +{ + friend std::ostream& operator<<(std::ostream&, const ProbingPT&); + +public: + ProbingPT(const std::string &line); + ~ProbingPT(); + + void Load(); + + void InitializeForInput(InputType const& source); + + // for phrase-based model + void GetTargetPhraseCollectionBatch(const InputPathList &inputPathQueue) const; + + // for syntax/hiero model (CKY+ decoding) + virtual ChartRuleLookupManager *CreateRuleLookupManager( + const ChartParser &, + const ChartCellCollectionBase &, + std::size_t); + + TO_STRING(); + + +protected: + QueryEngine *m_engine; + + typedef boost::bimap SourceVocabMap; + mutable SourceVocabMap m_sourceVocabMap; + + typedef boost::bimap TargetVocabMap; + mutable TargetVocabMap m_vocabMap; + + TargetPhraseCollection *CreateTargetPhrase(const Phrase &sourcePhrase) const; + TargetPhrase *CreateTargetPhrase(const Phrase &sourcePhrase, const target_text &probingTargetPhrase) const; + const Factor *GetTargetFactor(uint64_t probingId) const; + uint64_t GetSourceProbingId(const Factor *factor) const; + + std::vector ConvertToProbingSourcePhrase(const Phrase &sourcePhrase, bool &ok) const; + + uint64_t m_unkId; +}; + +} // namespace Moses diff 
--git a/moses/TranslationModel/ProbingPT/hash.cpp b/moses/TranslationModel/ProbingPT/hash.cpp new file mode 100644 index 000000000..1049292b1 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/hash.cpp @@ -0,0 +1,27 @@ +#include "hash.hh" + +uint64_t getHash(StringPiece text) { + std::size_t len = text.size(); + uint64_t key = util::MurmurHashNative(text.data(), len); + return key; +} + +std::vector getVocabIDs(StringPiece textin){ + //Tokenize + std::vector output; + + util::TokenIter it(textin, util::SingleCharacter(' ')); + + while(it){ + output.push_back(getHash(*it)); + it++; + } + + return output; +} + +uint64_t getVocabID(std::string candidate) { + std::size_t len = candidate.length(); + uint64_t key = util::MurmurHashNative(candidate.c_str(), len); + return key; +} \ No newline at end of file diff --git a/moses/TranslationModel/ProbingPT/hash.hh b/moses/TranslationModel/ProbingPT/hash.hh new file mode 100644 index 000000000..a4fcd6330 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/hash.hh @@ -0,0 +1,14 @@ +#pragma once + +#include "util/string_piece.hh" +#include "util/murmur_hash.hh" +#include "util/string_piece.hh" //Tokenization and work with StringPiece +#include "util/tokenize_piece.hh" +#include + +//Gets the MurmurmurHash for give string +uint64_t getHash(StringPiece text); + +std::vector getVocabIDs(StringPiece textin); + +uint64_t getVocabID(std::string candidate); \ No newline at end of file diff --git a/moses/TranslationModel/ProbingPT/huffmanish.cpp b/moses/TranslationModel/ProbingPT/huffmanish.cpp new file mode 100644 index 000000000..eea0a7c53 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/huffmanish.cpp @@ -0,0 +1,414 @@ +#include "huffmanish.hh" + +Huffman::Huffman (const char * filepath) { + //Read the file + util::FilePiece filein(filepath); + + //Init uniq_lines to zero; + uniq_lines = 0; + + line_text prev_line; //Check for unique lines. 
+ int num_lines = 0 ; + + while (true){ + line_text new_line; + + num_lines++; + + try { + //Process line read + new_line = splitLine(filein.ReadLine()); + count_elements(new_line); //Counts the number of elements, adds new and increments counters. + + } catch (util::EndOfFileException e){ + std::cerr << "Unique entries counted: "; + break; + } + + if (new_line.source_phrase == prev_line.source_phrase){ + continue; + } else { + uniq_lines++; + prev_line = new_line; + } + } + + std::cerr << uniq_lines << std::endl; +} + +void Huffman::count_elements(line_text linein){ + //For target phrase: + util::TokenIter it(linein.target_phrase, util::SingleCharacter(' ')); + while (it) { + //Check if we have that entry + std::map::iterator mapiter; + mapiter = target_phrase_words.find(it->as_string()); + + if (mapiter != target_phrase_words.end()){ + //If the element is found, increment the count. + mapiter->second++; + } else { + //Else create a new entry; + target_phrase_words.insert(std::pair(it->as_string(), 1)); + } + it++; + } + + //For word allignment 1 + std::map, unsigned int>::iterator mapiter3; + std::vector numbers = splitWordAll1(linein.word_all1); + mapiter3 = word_all1.find(numbers); + + if (mapiter3 != word_all1.end()){ + //If the element is found, increment the count. + mapiter3->second++; + } else { + //Else create a new entry; + word_all1.insert(std::pair, unsigned int>(numbers, 1)); + } + +} + +//Assigns huffman values for each unique element +void Huffman::assign_values() { + //First create vectors for all maps so that we could sort them later. 
+ + //Create a vector for target phrases + for(std::map::iterator it = target_phrase_words.begin(); it != target_phrase_words.end(); it++ ) { + target_phrase_words_counts.push_back(*it); + } + //Sort it + std::sort(target_phrase_words_counts.begin(), target_phrase_words_counts.end(), sort_pair()); + + //Create a vector for word allignments 1 + for(std::map, unsigned int>::iterator it = word_all1.begin(); it != word_all1.end(); it++ ) { + word_all1_counts.push_back(*it); + } + //Sort it + std::sort(word_all1_counts.begin(), word_all1_counts.end(), sort_pair_vec()); + + + //Afterwards we assign a value for each phrase, starting from 1, as zero is reserved for delimiter + unsigned int i = 1; //huffman code + for(std::vector >::iterator it = target_phrase_words_counts.begin(); + it != target_phrase_words_counts.end(); it++){ + target_phrase_huffman.insert(std::pair(it->first, i)); + i++; //Go to the next huffman code + } + + i = 1; //Reset i for the next map + for(std::vector, unsigned int> >::iterator it = word_all1_counts.begin(); + it != word_all1_counts.end(); it++){ + word_all1_huffman.insert(std::pair, unsigned int>(it->first, i)); + i++; //Go to the next huffman code + } + + //After lookups are produced, clear some memory usage of objects not needed anymore. + target_phrase_words.clear(); + word_all1.clear(); + + target_phrase_words_counts.clear(); + word_all1_counts.clear(); + + std::cerr << "Finished generating huffman codes." << std::endl; + +} + +void Huffman::serialize_maps(const char * dirname){ + //Note that directory name should exist. 
+ std::string basedir(dirname); + std::string target_phrase_path(basedir + "/target_phrases"); + std::string probabilities_path(basedir + "/probs"); + std::string word_all1_path(basedir + "/Wall1"); + + //Target phrase + std::ofstream os (target_phrase_path.c_str(), std::ios::binary); + boost::archive::text_oarchive oarch(os); + oarch << lookup_target_phrase; + os.close(); + + //Word all1 + std::ofstream os2 (word_all1_path.c_str(), std::ios::binary); + boost::archive::text_oarchive oarch2(os2); + oarch2 << lookup_word_all1; + os2.close(); +} + +std::vector Huffman::full_encode_line(line_text line){ + return vbyte_encode_line((encode_line(line))); +} + +std::vector Huffman::encode_line(line_text line){ + std::vector retvector; + + //Get target_phrase first. + util::TokenIter it(line.target_phrase, util::SingleCharacter(' ')); + while (it) { + retvector.push_back(target_phrase_huffman.find(it->as_string())->second); + it++; + } + //Add a zero; + retvector.push_back(0); + + //Get probabilities. Reinterpreting the float bytes as unsgined int. 
+ util::TokenIter probit(line.prob, util::SingleCharacter(' ')); + while (probit) { + //Sometimes we have too big floats to handle, so first convert to double + double tempnum = atof(probit->data()); + float num = (float)tempnum; + retvector.push_back(reinterpret_float(&num)); + probit++; + } + //Add a zero; + retvector.push_back(0); + + + //Get Word allignments + retvector.push_back(word_all1_huffman.find(splitWordAll1(line.word_all1))->second); + retvector.push_back(0); + + return retvector; +} + +void Huffman::produce_lookups(){ + //basically invert every map that we have + for(std::map::iterator it = target_phrase_huffman.begin(); it != target_phrase_huffman.end(); it++ ) { + lookup_target_phrase.insert(std::pair(it->second, it->first)); + } + + for(std::map, unsigned int>::iterator it = word_all1_huffman.begin(); it != word_all1_huffman.end(); it++ ) { + lookup_word_all1.insert(std::pair >(it->second, it->first)); + } + +} + +HuffmanDecoder::HuffmanDecoder (const char * dirname){ + //Read the maps from disk + + //Note that directory name should exist. 
+ std::string basedir(dirname); + std::string target_phrase_path(basedir + "/target_phrases"); + std::string word_all1_path(basedir + "/Wall1"); + + //Target phrases + std::ifstream is (target_phrase_path.c_str(), std::ios::binary); + boost::archive::text_iarchive iarch(is); + iarch >> lookup_target_phrase; + is.close(); + + //Word allignment 1 + std::ifstream is2 (word_all1_path.c_str(), std::ios::binary); + boost::archive::text_iarchive iarch2(is2); + iarch2 >> lookup_word_all1; + is2.close(); + +} + +HuffmanDecoder::HuffmanDecoder (std::map * lookup_target, + std::map > * lookup_word1) { + lookup_target_phrase = *lookup_target; + lookup_word_all1 = *lookup_word1; +} + +std::vector HuffmanDecoder::full_decode_line (std::vector lines){ + std::vector retvector; //All target phrases + std::vector decoded_lines = vbyte_decode_line(lines); //All decoded lines + std::vector::iterator it = decoded_lines.begin(); //Iterator for them + std::vector current_target_phrase; //Current target phrase decoded + + short zero_count = 0; //Count home many zeroes we have met. so far. Every 3 zeroes mean a new target phrase. + while(it != decoded_lines.end()){ + if (zero_count == 3) { + //We have finished with this entry, decode it, and add it to the retvector. + retvector.push_back(decode_line(current_target_phrase)); + current_target_phrase.clear(); //Clear the current target phrase and the zero_count + zero_count = 0; //So that we can reuse them for the next target phrase + } + //Add to the next target_phrase, number by number. + current_target_phrase.push_back(*it); + if (*it == 0) { + zero_count++; + } + it++; //Go to the next word/symbol + } + //Don't forget the last remaining line! + if (zero_count == 3) { + //We have finished with this entry, decode it, and add it to the retvector. 
+ retvector.push_back(decode_line(current_target_phrase)); + current_target_phrase.clear(); //Clear the current target phrase and the zero_count + zero_count = 0; //So that we can reuse them for the next target phrase + } + + return retvector; + +} + +target_text HuffmanDecoder::decode_line (std::vector input){ + //demo decoder + target_text ret; + //Split everything + std::vector target_phrase; + std::vector probs; + unsigned int wAll; + + //Split the line into the proper arrays + short num_zeroes = 0; + int counter = 0; + while (num_zeroes < 3){ + unsigned int num = input[counter]; + if (num == 0) { + num_zeroes++; + } else if (num_zeroes == 0){ + target_phrase.push_back(num); + } else if (num_zeroes == 1){ + probs.push_back(num); + } else if (num_zeroes == 2){ + wAll = num; + } + counter++; + } + + ret.target_phrase = target_phrase; + ret.word_all1 = lookup_word_all1.find(wAll)->second; + + //Decode probabilities + for (std::vector::iterator it = probs.begin(); it != probs.end(); it++){ + ret.prob.push_back(reinterpret_uint(&(*it))); + } + + return ret; + +} + +inline std::string HuffmanDecoder::getTargetWordFromID(unsigned int id) { + return lookup_target_phrase.find(id)->second; +} + +std::string HuffmanDecoder::getTargetWordsFromIDs(std::vector ids){ + std::string returnstring; + for (std::vector::iterator it = ids.begin(); it != ids.end(); it++){ + returnstring.append(getTargetWordFromID(*it) + " "); + } + + return returnstring; +} + +inline std::string getTargetWordFromID(unsigned int id, std::map * lookup_target_phrase) { + return lookup_target_phrase->find(id)->second; +} + +std::string getTargetWordsFromIDs(std::vector ids, std::map * lookup_target_phrase) { + std::string returnstring; + for (std::vector::iterator it = ids.begin(); it != ids.end(); it++){ + returnstring.append(getTargetWordFromID(*it, lookup_target_phrase) + " "); + } + + return returnstring; +} + +/*Those functions are used to more easily store the floats in the binary phrase table + We 
/* Helpers for storing floats in the binary phrase table.  A float's bit
   pattern is reinterpreted as an unsigned int so it can share the
   variable-byte encoding used for all other values.  std::memcpy is used
   instead of reinterpret_cast: pointer-based type punning violates strict
   aliasing (UB), while memcpy yields the identical bit pattern legally. */

// Returns the bit pattern of *num as an unsigned int.
inline unsigned int reinterpret_float(float * num)
{
  unsigned int bits;
  std::memcpy(&bits, num, sizeof(bits));
  return bits;
}

// Returns the float whose bit pattern is *num (inverse of reinterpret_float).
inline float reinterpret_uint(unsigned int * num)
{
  float value;
  std::memcpy(&value, num, sizeof(value));
  return value;
}

/* Variable-byte (LEB128-style) codec.  Each encoded byte carries 7 payload
   bits, least-significant group first; the high bit (0x80) is a continuation
   flag set on every byte except the last byte of a number.  The loop below
   produces byte-for-byte the same output as the previous unrolled
   goto-ladder version, without the gotos. */

// Encodes one number into 1-5 bytes depending on magnitude.
inline std::vector<unsigned char> vbyte_encode(unsigned int num)
{
  std::vector<unsigned char> byte_vector;
  byte_vector.reserve(5); // a 32-bit value never needs more than 5 groups

  while (num >= 0x80u) {
    byte_vector.push_back(static_cast<unsigned char>((num & 0x7fu) | 0x80u));
    num >>= 7;
  }
  byte_vector.push_back(static_cast<unsigned char>(num)); // final byte: flag clear

  return byte_vector;
}

// Reassembles one number from its encoded bytes (low 7-bit group first).
inline unsigned int bytes_to_int(std::vector<unsigned char> number)
{
  unsigned int retvalue = 0;
  unsigned char shift = 0; // how far the current 7-bit group is shifted

  for (std::vector<unsigned char>::iterator it = number.begin(); it != number.end(); ++it) {
    retvalue |= static_cast<unsigned int>(*it & 0x7fu) << shift;
    shift += 7;
  }

  return retvalue;
}

// Decodes a whole encoded line: groups bytes until one with the continuation
// flag clear, then converts each group back into a number.
std::vector<unsigned int> vbyte_decode_line(std::vector<unsigned char> line)
{
  std::vector<unsigned int> huffman_line;
  std::vector<unsigned char> current_num;

  for (std::vector<unsigned char>::iterator it = line.begin(); it != line.end(); ++it) {
    current_num.push_back(*it);
    if ((*it >> 7) != 1) {
      // Continuation flag clear: this byte terminates the current number.
      huffman_line.push_back(bytes_to_int(current_num));
      current_num.clear();
    }
  }
  return huffman_line;
}

// Variable-byte encodes every number of a line and concatenates the bytes.
std::vector<unsigned char> vbyte_encode_line(std::vector<unsigned int> line)
{
  std::vector<unsigned char> retvec;

  for (std::vector<unsigned int>::iterator it = line.begin(); it != line.end(); ++it) {
    std::vector<unsigned char> vbyte_encoded = vbyte_encode(*it);
    retvec.insert(retvec.end(), vbyte_encoded.begin(), vbyte_encoded.end());
  }

  return retvec;
}
+ for (std::vector::iterator it = line.begin(); it != line.end(); it++){ + std::vector vbyte_encoded = vbyte_encode(*it); + retvec.insert(retvec.end(), vbyte_encoded.begin(), vbyte_encoded.end()); + } + + return retvec; +} diff --git a/moses/TranslationModel/ProbingPT/huffmanish.hh b/moses/TranslationModel/ProbingPT/huffmanish.hh new file mode 100644 index 000000000..3116484e9 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/huffmanish.hh @@ -0,0 +1,110 @@ +#pragma once + +//Huffman encodes a line and also produces the vocabulary ids +#include "hash.hh" +#include "line_splitter.hh" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//Sorting for the second +struct sort_pair { + bool operator()(const std::pair &left, const std::pair &right) { + return left.second > right.second; //This puts biggest numbers first. + } +}; + +struct sort_pair_vec { + bool operator()(const std::pair, unsigned int> &left, const std::pair, unsigned int> &right) { + return left.second > right.second; //This puts biggest numbers first. + } +}; + +class Huffman { + unsigned long uniq_lines; //Unique lines in the file. 
+ + //Containers used when counting the occurence of a given phrase + std::map target_phrase_words; + std::map, unsigned int> word_all1; + + //Same containers as vectors, for sorting + std::vector > target_phrase_words_counts; + std::vector, unsigned int> > word_all1_counts; + + //Huffman maps + std::map target_phrase_huffman; + std::map, unsigned int> word_all1_huffman; + + //inverted maps + std::map lookup_target_phrase; + std::map > lookup_word_all1; + + public: + Huffman (const char *); + void count_elements (line_text line); + void assign_values(); + void serialize_maps(const char * dirname); + void produce_lookups(); + + std::vector encode_line(line_text line); + + //encode line + variable byte ontop + std::vector full_encode_line(line_text line); + + //Getters + const std::map get_target_lookup_map() const{ + return lookup_target_phrase; + } + const std::map > get_word_all1_lookup_map() const{ + return lookup_word_all1; + } + + unsigned long getUniqLines() { + return uniq_lines; + } +}; + +class HuffmanDecoder { + std::map lookup_target_phrase; + std::map > lookup_word_all1; + +public: + HuffmanDecoder (const char *); + HuffmanDecoder (std::map *, std::map > *); + + //Getters + const std::map get_target_lookup_map() const{ + return lookup_target_phrase; + } + const std::map > get_word_all1_lookup_map() const{ + return lookup_word_all1; + } + + inline std::string getTargetWordFromID(unsigned int id); + + std::string getTargetWordsFromIDs(std::vector ids); + + target_text decode_line (std::vector input); + + //Variable byte decodes a all target phrases contained here and then passes them to decode_line + std::vector full_decode_line (std::vector lines); +}; + +std::string getTargetWordsFromIDs(std::vector ids, std::map * lookup_target_phrase); + +inline std::string getTargetWordFromID(unsigned int id, std::map * lookup_target_phrase); + +inline unsigned int reinterpret_float(float * num); + +inline float reinterpret_uint(unsigned int * num); + +std::vector 
vbyte_encode_line(std::vector line); +inline std::vector vbyte_encode(unsigned int num); +std::vector vbyte_decode_line(std::vector line); +inline unsigned int bytes_to_int(std::vector number); diff --git a/moses/TranslationModel/ProbingPT/line_splitter.cpp b/moses/TranslationModel/ProbingPT/line_splitter.cpp new file mode 100644 index 000000000..f50090e4c --- /dev/null +++ b/moses/TranslationModel/ProbingPT/line_splitter.cpp @@ -0,0 +1,52 @@ +#include "line_splitter.hh" + +line_text splitLine(StringPiece textin) { + const char delim[] = " ||| "; + line_text output; + + //Tokenize + util::TokenIter it(textin, util::MultiCharacter(delim)); + //Get source phrase + output.source_phrase = *it; + it++; + //Get target_phrase + output.target_phrase = *it; + it++; + //Get probabilities + output.prob = *it; + it++; + //Get WordAllignment 1 + output.word_all1 = *it; + it++; + //Get WordAllignment 2 + output.word_all2 = *it; + + return output; +} + +std::vector splitWordAll1(StringPiece textin){ + const char delim[] = " "; + const char delim2[] = "-"; + std::vector output; + + //Split on space + util::TokenIter it(textin, util::MultiCharacter(delim)); + + //For each int + while (it) { + //Split on dash (-) + util::TokenIter itInner(*it, util::MultiCharacter(delim2)); + + //Insert the two entries in the vector. User will read entry 0 and 1 to get the first, + //2 and 3 for second etc. 
Use unsigned char instead of int to save space, as + //word allignments are all very small numbers that fit in a single byte + output.push_back((unsigned char)(atoi(itInner->data()))); + itInner++; + output.push_back((unsigned char)(atoi(itInner->data()))); + it++; + } + + return output; + +} + diff --git a/moses/TranslationModel/ProbingPT/line_splitter.hh b/moses/TranslationModel/ProbingPT/line_splitter.hh new file mode 100644 index 000000000..c699a28c0 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/line_splitter.hh @@ -0,0 +1,31 @@ +#pragma once + +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" +#include "util/file_piece.hh" +#include +#include //atof +#include "util/string_piece.hh" //Tokenization and work with StringPiece +#include "util/tokenize_piece.hh" +#include + +//Struct for holding processed line +struct line_text { + StringPiece source_phrase; + StringPiece target_phrase; + StringPiece prob; + StringPiece word_all1; + StringPiece word_all2; +}; + +//Struct for holding processed line +struct target_text { + std::vector target_phrase; + std::vector prob; + std::vector word_all1; +}; + +//Ask if it's better to have it receive a pointer to a line_text struct +line_text splitLine(StringPiece textin); + +std::vector splitWordAll1(StringPiece textin); diff --git a/moses/TranslationModel/ProbingPT/probing_hash_utils.cpp b/moses/TranslationModel/ProbingPT/probing_hash_utils.cpp new file mode 100644 index 000000000..35cb9e538 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/probing_hash_utils.cpp @@ -0,0 +1,32 @@ +#include "probing_hash_utils.hh" + +//Read table from disk, return memory map location +char * readTable(const char * filename, size_t size) { + //Initial position of the file is the end of the file, thus we know the size + int fd; + char * map; + + fd = open(filename, O_RDONLY); + if (fd == -1) { + perror("Error opening file for reading"); + exit(EXIT_FAILURE); + } + + map = (char *)mmap(0, size, PROT_READ, MAP_SHARED, 
// Writes `size` bytes starting at `mem` to `filename` in binary mode,
// creating or truncating the file.  Used to persist the probing hash table.
// Reports and exits on open failure instead of silently writing nothing,
// matching the error style of readTable()/read_binary_file().
void serialize_table(char *mem, size_t size, const char * filename)
{
  std::ofstream os(filename, std::ios::binary);
  if (!os) {
    std::cerr << "Error opening " << filename << " for writing" << std::endl;
    exit(EXIT_FAILURE);
  }
  os.write(mem, size);
  os.close();
}

// Memory-maps `filesize` bytes of `filename` read-only and returns the
// mapping; exits on failure.  The descriptor is closed once mmap succeeds:
// per POSIX the mapping stays valid after close, and keeping the fd open
// (as before) leaked one descriptor per call.
unsigned char * read_binary_file(const char * filename, size_t filesize)
{
  int fd = open(filename, O_RDONLY);
  if (fd == -1) {
    perror("Error opening file for reading");
    exit(EXIT_FAILURE);
  }

  unsigned char * map = (unsigned char *)mmap(0, filesize, PROT_READ, MAP_SHARED, fd, 0);
  if (map == MAP_FAILED) {
    close(fd);
    perror("Error mmapping the file");
    exit(EXIT_FAILURE);
  }

  close(fd); // mapping remains valid without the descriptor
  return map;
}
basepath(filepath); + std::string path_to_hashtable = basepath + "/probing_hash.dat"; + std::string path_to_data_bin = basepath + "/binfile.dat"; + std::string path_to_source_vocabid = basepath + "/source_vocabids"; + + ///Source phrase vocabids + read_map(&source_vocabids, path_to_source_vocabid.c_str()); + + //Target phrase vocabIDs + vocabids = decoder.get_target_lookup_map(); + + //Read config file + std::string line; + std::ifstream config ((basepath + "/config").c_str()); + getline(config, line); + int tablesize = atoi(line.c_str()); //Get tablesize. + config.close(); + + //Mmap binary table + struct stat filestatus; + stat(path_to_data_bin.c_str(), &filestatus); + binary_filesize = filestatus.st_size; + binary_mmaped = read_binary_file(path_to_data_bin.c_str(), binary_filesize); + + //Read hashtable + size_t table_filesize = Table::Size(tablesize, 1.2); + mem = readTable(path_to_hashtable.c_str(), table_filesize); + Table table_init(mem, table_filesize); + table = table_init; + + std::cerr << "Initialized successfully! " << std::endl; +} + +QueryEngine::~QueryEngine(){ + //Clear mmap content from memory. + munmap(binary_mmaped, binary_filesize); + munmap(mem, table_filesize); + +} + +std::pair > QueryEngine::query(std::vector source_phrase){ + bool found; + std::vector translation_entries; + const Entry * entry; + //TOO SLOW + //uint64_t key = util::MurmurHashNative(&source_phrase[0], source_phrase.size()); + uint64_t key = 0; + for (int i = 0; i < source_phrase.size(); i++){ + key += source_phrase[i]; + } + + + found = table.Find(key, entry); + + if (found){ + //The phrase that was searched for was found! We need to get the translation entries. + //We will read the largest entry in bytes and then filter the unnecesarry with functions + //from line_splitter + uint64_t initial_index = entry -> GetValue(); + unsigned int bytes_toread = entry -> bytes_toread; + + //ASK HIEU FOR MORE EFFICIENT WAY TO DO THIS! 
+ std::vector encoded_text; //Assign to the vector the relevant portion of the array. + encoded_text.reserve(bytes_toread); + for (int i = 0; i < bytes_toread; i++){ + encoded_text.push_back(binary_mmaped[i+initial_index]); + } + + //Get only the translation entries necessary + translation_entries = decoder.full_decode_line(encoded_text); + + } + + std::pair > output (found, translation_entries); + + return output; + +} + +std::pair > QueryEngine::query(StringPiece source_phrase){ + bool found; + std::vector translation_entries; + const Entry * entry; + //Convert source frase to VID + std::vector source_phrase_vid = getVocabIDs(source_phrase); + //TOO SLOW + //uint64_t key = util::MurmurHashNative(&source_phrase_vid[0], source_phrase_vid.size()); + uint64_t key = 0; + for (int i = 0; i < source_phrase_vid.size(); i++){ + key += source_phrase_vid[i]; + } + + found = table.Find(key, entry); + + + if (found){ + //The phrase that was searched for was found! We need to get the translation entries. + //We will read the largest entry in bytes and then filter the unnecesarry with functions + //from line_splitter + uint64_t initial_index = entry -> GetValue(); + unsigned int bytes_toread = entry -> bytes_toread; + //At the end of the file we can't readd + largest_entry cause we get a segfault. + std::cerr << "Entry size is bytes is: " << bytes_toread << std::endl; + + //ASK HIEU FOR MORE EFFICIENT WAY TO DO THIS! + std::vector encoded_text; //Assign to the vector the relevant portion of the array. 
+ encoded_text.reserve(bytes_toread); + for (int i = 0; i < bytes_toread; i++){ + encoded_text.push_back(binary_mmaped[i+initial_index]); + } + + //Get only the translation entries necessary + translation_entries = decoder.full_decode_line(encoded_text); + + } + + std::pair > output (found, translation_entries); + + return output; + +} + +void QueryEngine::printTargetInfo(std::vector target_phrases){ + int entries = target_phrases.size(); + + for (int i = 0; i //For finding size of file +#include "vocabid.hh" + + +char * read_binary_file(char * filename); + +class QueryEngine { + unsigned char * binary_mmaped; //The binari phrase table file + std::map vocabids; + std::map source_vocabids; + + Table table; + char *mem; //Memory for the table, necessary so that we can correctly destroy the object + + HuffmanDecoder decoder; + + size_t binary_filesize; + size_t table_filesize; + public: + QueryEngine (const char *); + ~QueryEngine(); + std::pair > query(StringPiece source_phrase); + std::pair > query(std::vector source_phrase); + void printTargetInfo(std::vector target_phrases); + const std::map getVocab() const + { return decoder.get_target_lookup_map(); } + + const std::map getSourceVocab() const { + return source_vocabids; + } + +}; + + diff --git a/moses/TranslationModel/ProbingPT/storing.cpp b/moses/TranslationModel/ProbingPT/storing.cpp new file mode 100644 index 000000000..5ea0df39c --- /dev/null +++ b/moses/TranslationModel/ProbingPT/storing.cpp @@ -0,0 +1,151 @@ +#include "storing.hh" + +BinaryFileWriter::BinaryFileWriter (std::string basepath) : os ((basepath + "/binfile.dat").c_str(), std::ios::binary) { + binfile.reserve(10000); //Reserve part of the vector to avoid realocation + it = binfile.begin(); + dist_from_start = 0; //Initialize variables + extra_counter = 0; +} + +void BinaryFileWriter::write (std::vector * bytes) { + binfile.insert(it, bytes->begin(), bytes->end()); //Insert the bytes + //Keep track of the offsets + it += bytes->size(); + 
dist_from_start = distance(binfile.begin(),it); + //Flush the vector to disk every once in a while so that we don't consume too much ram + if (dist_from_start > 9000) { + flush(); + } +} + +void BinaryFileWriter::flush () { + //Cast unsigned char to char before writing... + os.write((char *)&binfile[0], dist_from_start); + //Clear the vector: + binfile.clear(); + binfile.reserve(10000); + extra_counter += dist_from_start; //Keep track of the total number of bytes. + it = binfile.begin(); //Reset iterator + dist_from_start = distance(binfile.begin(),it); //Reset dist from start +} + +BinaryFileWriter::~BinaryFileWriter (){ + os.close(); + binfile.clear(); +} + +void createProbingPT(const char * phrasetable_path, const char * target_path){ + //Get basepath and create directory if missing + std::string basepath(target_path); + mkdir(basepath.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); + + //Set up huffman and serialize decoder maps. + Huffman huffmanEncoder(phrasetable_path); //initialize + huffmanEncoder.assign_values(); + huffmanEncoder.produce_lookups(); + huffmanEncoder.serialize_maps(target_path); + + //Get uniq lines: + unsigned long uniq_entries = huffmanEncoder.getUniqLines(); + + //Source phrase vocabids + std::map source_vocabids; + + //Read the file + util::FilePiece filein(phrasetable_path); + + //Init the probing hash table + size_t size = Table::Size(uniq_entries, 1.2); + char * mem = new char[size]; + memset(mem, 0, size); + Table table(mem, size); + + BinaryFileWriter binfile(basepath); //Init the binary file writer. 
+ + line_text prev_line; //Check if the source phrase of the previous line is the same + + //Keep track of the size of each group of target phrases + uint64_t entrystartidx = 0; + //uint64_t line_num = 0; + + + //Read everything and processs + while(true){ + try { + //Process line read + line_text line; + line = splitLine(filein.ReadLine()); + //Add source phrases to vocabularyIDs + add_to_map(&source_vocabids, line.source_phrase); + + if ((binfile.dist_from_start + binfile.extra_counter) == 0) { + prev_line = line; //For the first iteration assume the previous line is + } //The same as this one. + + if (line.source_phrase != prev_line.source_phrase){ + + //Create a new entry even + + //Create an entry for the previous source phrase: + Entry pesho; + pesho.value = entrystartidx; + //The key is the sum of hashes of individual words. Probably not entirerly correct, but fast + pesho.key = 0; + std::vector vocabid_source = getVocabIDs(prev_line.source_phrase); + for (int i = 0; i < vocabid_source.size(); i++){ + pesho.key += vocabid_source[i]; + } + pesho.bytes_toread = binfile.dist_from_start + binfile.extra_counter - entrystartidx; + + //Put into table + table.Insert(pesho); + + entrystartidx = binfile.dist_from_start + binfile.extra_counter; //Designate start idx for new entry + + //Encode a line and write it to disk. + std::vector encoded_line = huffmanEncoder.full_encode_line(line); + binfile.write(&encoded_line); + + //Set prevLine + prev_line = line; + + } else{ + //If we still have the same line, just append to it: + std::vector encoded_line = huffmanEncoder.full_encode_line(line); + binfile.write(&encoded_line); + } + + } catch (util::EndOfFileException e){ + std::cerr << "Reading phrase table finished, writing remaining files to disk." 
<< std::endl; + binfile.flush(); + + //After the final entry is constructed we need to add it to the phrase_table + //Create an entry for the previous source phrase: + Entry pesho; + pesho.value = entrystartidx; + //The key is the sum of hashes of individual words. Probably not entirerly correct, but fast + pesho.key = 0; + std::vector vocabid_source = getVocabIDs(prev_line.source_phrase); + for (int i = 0; i < vocabid_source.size(); i++){ + pesho.key += vocabid_source[i]; + } + pesho.bytes_toread = binfile.dist_from_start + binfile.extra_counter - entrystartidx; + //Put into table + table.Insert(pesho); + + break; + } + } + + serialize_table(mem, size, (basepath + "/probing_hash.dat").c_str()); + + serialize_map(&source_vocabids, (basepath + "/source_vocabids").c_str()); + + delete[] mem; + + //Write configfile + std::ofstream configfile; + configfile.open((basepath + "/config").c_str()); + configfile << uniq_entries << '\n'; + configfile.close(); +} diff --git a/moses/TranslationModel/ProbingPT/storing.hh b/moses/TranslationModel/ProbingPT/storing.hh new file mode 100644 index 000000000..dfcdbcc41 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/storing.hh @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +#include "hash.hh" //Includes line_splitter +#include "probing_hash_utils.hh" +#include "huffmanish.hh" +#include //mkdir + +#include "util/file_piece.hh" +#include "util/file.hh" +#include "vocabid.hh" + +void createProbingPT(const char * phrasetable_path, const char * target_path); + +class BinaryFileWriter { + std::vector binfile; + std::vector::iterator it; + //Output binary + std::ofstream os; + +public: + unsigned int dist_from_start; //Distance from the start of the vector. 
+ uint64_t extra_counter; //After we reset the counter, we still want to keep track of the correct offset, so + + BinaryFileWriter (std::string); + ~BinaryFileWriter (); + void write (std::vector * bytes); + void flush (); //Flush to disk + +}; diff --git a/moses/TranslationModel/ProbingPT/tests/tokenization_tests.cpp b/moses/TranslationModel/ProbingPT/tests/tokenization_tests.cpp new file mode 100644 index 000000000..2a63242de --- /dev/null +++ b/moses/TranslationModel/ProbingPT/tests/tokenization_tests.cpp @@ -0,0 +1,198 @@ +#include "line_splitter.hh" + +bool test_vectorinsert() { + StringPiece line1 = StringPiece("! ! ! ! ||| ! ! ! ! ||| 0.0804289 0.141656 0.0804289 0.443409 2.718 ||| 0-0 1-1 2-2 3-3 ||| 1 1 1"); + StringPiece line2 = StringPiece("! ! ! ) , has ||| ! ! ! ) - , a ||| 0.0804289 0.0257627 0.0804289 0.00146736 2.718 ||| 0-0 1-1 2-2 3-3 4-4 4-5 5-6 ||| 1 1 1"); + line_text output = splitLine(line1); + line_text output2 = splitLine(line2); + + //Init container vector and iterator. + std::vector container; + container.reserve(10000); //Reserve vector + std::vector::iterator it = container.begin(); + std::pair::iterator, int> binary_append_ret; //Return values from vector_append + + //Put a value into the vector + binary_append_ret = vector_append(&output, &container, it, false); + it = binary_append_ret.first; + binary_append_ret = vector_append(&output2, &container, it, false); + it = binary_append_ret.first; + + std::string test(container.begin(), container.end()); + std::string should_be = "! ! ! ! 0.0804289 0.141656 0.0804289 0.443409 2.718 0-0 1-1 2-2 3-3 1 1 1! ! ! 
) - , a 0.0804289 0.0257627 0.0804289 0.00146736 2.718 0-0 1-1 2-2 3-3 4-4 4-5 5-6 1 1 1"; + if (test == should_be) { + return true; + } else { + return false; + } +} + +bool probabilitiesTest(){ + StringPiece line1 = StringPiece("0.536553 0.75961 0.634108 0.532927 2.718"); + StringPiece line2 = StringPiece("1.42081e-05 3.91895e-09 0.0738539 0.749514 2.718"); + + std::vector pesho; + bool peshobool = false; + bool kirobool = false; + std::vector kiro; + + pesho = splitProbabilities(line1); + kiro = splitProbabilities(line2); + + if (pesho[0] == 0.536553 && pesho[1] == 0.75961 && pesho[2] == 0.634108 && pesho[3] == 0.532927 && pesho[4] == 2.718 && pesho.size() == 5) { + peshobool = true; + } else { + std::cout << "Processed: " << pesho[0] << " " << pesho[1] << " " << pesho[2] << " " << pesho[3] << " " << pesho[4] << std::endl; + std::cout << "Size is: " << pesho.size() << " Expected 5." << std::endl; + std::cout << "Expected: " << "0.536553 0.75961 0.634108 0.532927 2.718" << std::endl; + } + + if (kiro[0] == 1.42081e-05 && kiro[1] == 3.91895e-09 && kiro[2] == 0.0738539 && kiro[3] == 0.749514 && kiro[4] == 2.718 && kiro.size() == 5) { + kirobool = true; + } else { + std::cout << "Processed: " << kiro[0] << " " << kiro[1] << " " << kiro[2] << " " << kiro[3] << " " << kiro[4] << std::endl; + std::cout << "Size is: " << kiro.size() << " Expected 5." 
<< std::endl; + std::cout << "Expected: " << "1.42081e-05 3.91895e-09 0.0738539 0.749514 2.718" << std::endl; + } + + return (peshobool && kirobool); +} + +bool wordAll1test(){ + StringPiece line1 = StringPiece("2-0 3-1 4-2 5-2"); + StringPiece line2 = StringPiece("0-0 1-1 2-2 3-3 4-3 6-4 5-5"); + + std::vector pesho; + bool peshobool = false; + bool kirobool = false; + std::vector kiro; + + pesho = splitWordAll1(line1); + kiro = splitWordAll1(line2); + + if (pesho[0] == 2 && pesho[1] == 0 && pesho[2] == 3 && pesho[3] == 1 && pesho[4] == 4 + && pesho[5] == 2 && pesho[6] == 5 && pesho[7] == 2 && pesho.size() == 8) { + peshobool = true; + } else { + std::cout << "Processed: " << pesho[0] << "-" << pesho[1] << " " << pesho[2] << "-" << pesho[3] << " " + << pesho[4] << "-" << pesho[5] << " " << pesho[6] << "-" << pesho[7] << std::endl; + std::cout << "Size is: " << pesho.size() << " Expected: 8." << std::endl; + std::cout << "Expected: " << "2-0 3-1 4-2 5-2" << std::endl; + } + + if (kiro[0] == 0 && kiro[1] == 0 && kiro[2] == 1 && kiro[3] == 1 && kiro[4] == 2 && kiro[5] == 2 + && kiro[6] == 3 && kiro[7] == 3 && kiro[8] == 4 && kiro[9] == 3 && kiro[10] == 6 && kiro[11] == 4 + && kiro[12] == 5 && kiro[13] == 5 && kiro.size() == 14){ + kirobool = true; + } else { + std::cout << "Processed: " << kiro[0] << "-" << kiro[1] << " " << kiro[2] << "-" << kiro[3] << " " + << kiro[4] << "-" << kiro[5] << " " << kiro[6] << "-" << kiro[7] << " " << kiro[8] << "-" << kiro[9] + << " " << kiro[10] << "-" << kiro[11] << " " << kiro[12] << "-" << kiro[13] << std::endl; + std::cout << "Size is: " << kiro.size() << " Expected: 14" << std::endl; + std::cout << "Expected: " << "0-0 1-1 2-2 3-3 4-3 6-4 5-5" << std::endl; + } + + return (peshobool && kirobool); +} + +bool wordAll2test(){ + StringPiece line1 = StringPiece("4 9 1"); + StringPiece line2 = StringPiece("3255 9 1"); + + std::vector pesho; + bool peshobool = false; + bool kirobool = false; + std::vector kiro; + + pesho = 
splitWordAll2(line1); + kiro = splitWordAll2(line2); + + if (pesho[0] == 4 && pesho[1] == 9 && pesho[2] == 1 && pesho.size() == 3){ + peshobool = true; + } else { + std::cout << "Processed: " << pesho[0] << " " << pesho[1] << " " << pesho[2] << std::endl; + std::cout << "Size: " << pesho.size() << " Expected: 3" << std::endl; + std::cout << "Expected: " << "4 9 1" << std::endl; + } + + if (kiro[0] == 3255 && kiro[1] == 9 && kiro[2] == 1 && kiro.size() == 3){ + kirobool = true; + } else { + std::cout << "Processed: " << kiro[0] << " " << kiro[1] << " " << kiro[2] << std::endl; + std::cout << "Size: " << kiro.size() << " Expected: 3" << std::endl; + std::cout << "Expected: " << "3255 9 1" << std::endl; + } + + return (peshobool && kirobool); + +} + +bool test_tokenization(){ + StringPiece line1 = StringPiece("! ! ! ! ||| ! ! ! ! ||| 0.0804289 0.141656 0.0804289 0.443409 2.718 ||| 0-0 1-1 2-2 3-3 ||| 1 1 1"); + StringPiece line2 = StringPiece("! ! ! ) , has ||| ! ! ! ) - , a ||| 0.0804289 0.0257627 0.0804289 0.00146736 2.718 ||| 0-0 1-1 2-2 3-3 4-4 4-5 5-6 ||| 1 1 1"); + StringPiece line3 = StringPiece("! ! ! ) , ||| ! ! ! ) - , ||| 0.0804289 0.075225 0.0804289 0.00310345 2.718 ||| 0-0 1-1 2-2 3-3 4-4 4-5 ||| 1 1 1"); + StringPiece line4 = StringPiece("! ! ! ) ||| ! ! ! ) . ||| 0.0804289 0.177547 0.0268096 0.000872597 2.718 ||| 0-0 1-1 2-2 3-3 ||| 1 3 1"); + + line_text output1 = splitLine(line1); + line_text output2 = splitLine(line2); + line_text output3 = splitLine(line3); + line_text output4 = splitLine(line4); + + bool test1 = output1.prob == StringPiece("0.0804289 0.141656 0.0804289 0.443409 2.718"); + bool test2 = output2.word_all1 == StringPiece("0-0 1-1 2-2 3-3 4-4 4-5 5-6"); + bool test3 = output2.target_phrase == StringPiece("! ! ! ) - , a"); + bool test4 = output3.source_phrase == StringPiece("! ! ! 
) ,"); + bool test5 = output4.word_all2 == StringPiece("1 3 1"); + + //std::cout << test1 << " " << test2 << " " << test3 << " " << test4 << std::endl; + + return (test1 && test2 && test3 && test4 && test5); + +} + +bool test_linesplitter(){ + StringPiece line1 = StringPiece("! ] 0.0738539 0.901133 0.0738539 0.65207 2.718 0-0 1-1 1 1 1"); + target_text ans1; + ans1 = splitSingleTargetLine(line1); + + /* For testing purposes + std::cout << ans1.target_phrase[0] << " " < ans1; + std::vector ans2; + + ans1 = splitTargetLine(line1); + ans2 = splitTargetLine(line2); + + bool sizes = ans1.size() == 1 && ans2.size() == 4; + bool prob = ans1[0].prob[3] == 0.65207 && ans2[1].prob[1] == 0.00049839; + bool word_alls = ans2[0].word_all2[1] == 11 && ans2[3].word_all1[5] == 3; + + /* FOr testing + std::cout << ans1.size() << std::endl; + std::cout << ans2.size() << std::endl; + std::cout << ans1[0].prob[3] << std::endl; + std::cout << ans2[1].prob[1] << std::endl; + std::cout << ans2[0].word_all2[1] << std::endl; + std::cout << ans2[3].word_all1[5] << std::endl; */ + + return sizes && prob && word_alls; +} + +int main(){ + if (probabilitiesTest() && wordAll1test() && wordAll2test() && test_tokenization() && test_linesplitter() && test_linessplitter() && test_vectorinsert()){ + std::cout << "All tests pass!" << std::endl; + } else { + std::cout << "Failiure in some tests!" 
<< std::endl; + } + + return 1; +} \ No newline at end of file diff --git a/moses/TranslationModel/ProbingPT/tests/vocabid_test.cpp b/moses/TranslationModel/ProbingPT/tests/vocabid_test.cpp new file mode 100644 index 000000000..bc82db74e --- /dev/null +++ b/moses/TranslationModel/ProbingPT/tests/vocabid_test.cpp @@ -0,0 +1,45 @@ +#include //Map for vocab ids + +#include "hash.hh" +#include "vocabid.hh" + +int main(int argc, char* argv[]){ + + //Create a map and serialize it + std::map vocabids; + StringPiece demotext = StringPiece("Demo text with 3 elements"); + add_to_map(&vocabids, demotext); + //Serialize map + serialize_map(&vocabids, "/tmp/testmap.bin"); + + //Read the map and test if the values are the same + std::map newmap; + read_map(&newmap, "/tmp/testmap.bin"); + + //Used hashes + uint64_t num1 = getHash(StringPiece("Demo")); + uint64_t num2 = getVocabID("text"); + uint64_t num3 = getHash(StringPiece("with")); + uint64_t num4 = getVocabID("3"); + uint64_t num5 = getHash(StringPiece("elements")); + uint64_t num6 = 0; + + //Tests + bool test1 = getStringFromID(&newmap, num1) == getStringFromID(&vocabids, num1); + bool test2 = getStringFromID(&newmap, num2) == getStringFromID(&vocabids, num2); + bool test3 = getStringFromID(&newmap, num3) == getStringFromID(&vocabids, num3); + bool test4 = getStringFromID(&newmap, num4) == getStringFromID(&vocabids, num4); + bool test5 = getStringFromID(&newmap, num5) == getStringFromID(&vocabids, num5); + bool test6 = getStringFromID(&newmap, num6) == getStringFromID(&vocabids, num6); + + + if (test1 && test2 && test3 && test4 && test5 && test6){ + std::cout << "Map was successfully written and read!" << std::endl; + } else { + std::cout << "Error! 
" << test1 << " " << test2 << " " << test3 << " " << test4 << " " << test5 << " " << test6 << std::endl; + } + + + return 1; + +} diff --git a/moses/TranslationModel/ProbingPT/vocabid.cpp b/moses/TranslationModel/ProbingPT/vocabid.cpp new file mode 100644 index 000000000..bcdbe78d0 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/vocabid.cpp @@ -0,0 +1,29 @@ +#include "vocabid.hh" + +void add_to_map(std::map *karta, StringPiece textin){ + //Tokenize + util::TokenIter it(textin, util::SingleCharacter(' ')); + + while(it){ + karta->insert(std::pair(getHash(*it), it->as_string())); + it++; + } +} + +void serialize_map(std::map *karta, const char* filename){ + std::ofstream os (filename, std::ios::binary); + boost::archive::text_oarchive oarch(os); + + oarch << *karta; //Serialise map + os.close(); +} + +void read_map(std::map *karta, const char* filename){ + std::ifstream is (filename, std::ios::binary); + boost::archive::text_iarchive iarch(is); + + iarch >> *karta; + + //Close the stream after we are done. 
+ is.close(); +} diff --git a/moses/TranslationModel/ProbingPT/vocabid.hh b/moses/TranslationModel/ProbingPT/vocabid.hh new file mode 100644 index 000000000..491c53439 --- /dev/null +++ b/moses/TranslationModel/ProbingPT/vocabid.hh @@ -0,0 +1,20 @@ +//Serialization +#include +#include +#include +#include +#include +#include +#include + +#include //Container +#include "hash.hh" //Hash of elements + +#include "util/string_piece.hh" //Tokenization and work with StringPiece +#include "util/tokenize_piece.hh" + +void add_to_map(std::map *karta, StringPiece textin); + +void serialize_map(std::map *karta, const char* filename); + +void read_map(std::map *karta, const char* filename); diff --git a/moses/TranslationModel/RuleTable/PhraseDictionaryOnDisk.cpp b/moses/TranslationModel/RuleTable/PhraseDictionaryOnDisk.cpp index 627353097..581842494 100644 --- a/moses/TranslationModel/RuleTable/PhraseDictionaryOnDisk.cpp +++ b/moses/TranslationModel/RuleTable/PhraseDictionaryOnDisk.cpp @@ -36,6 +36,8 @@ namespace Moses { PhraseDictionaryOnDisk::PhraseDictionaryOnDisk(const std::string &line) : MyBase(line) + , m_maxSpanDefault(NOT_FOUND) + , m_maxSpanLabelled(NOT_FOUND) { ReadParameters(); } @@ -208,5 +210,19 @@ const TargetPhraseCollection *PhraseDictionaryOnDisk::GetTargetPhraseCollectionN return targetPhrases; } +void PhraseDictionaryOnDisk::SetParameter(const std::string& key, const std::string& value) +{ + if (key == "max-span-default") { + m_maxSpanDefault = Scan(value); + } + else if (key == "max-span-labelled") { + m_maxSpanLabelled = Scan(value); + } + else { + PhraseDictionary::SetParameter(key, value); + } +} + + } // namespace diff --git a/moses/TranslationModel/RuleTable/PhraseDictionaryOnDisk.h b/moses/TranslationModel/RuleTable/PhraseDictionaryOnDisk.h index f32b6ca1e..19548411c 100644 --- a/moses/TranslationModel/RuleTable/PhraseDictionaryOnDisk.h +++ b/moses/TranslationModel/RuleTable/PhraseDictionaryOnDisk.h @@ -48,6 +48,7 @@ class PhraseDictionaryOnDisk : 
public PhraseDictionary { typedef PhraseDictionary MyBase; friend std::ostream& operator<<(std::ostream&, const PhraseDictionaryOnDisk&); + friend class ChartRuleLookupManagerOnDisk; protected: #ifdef WITH_THREADS @@ -56,6 +57,8 @@ protected: boost::scoped_ptr m_implementation; #endif + size_t m_maxSpanDefault, m_maxSpanLabelled; + OnDiskPt::OnDiskWrapper &GetImplementation(); const OnDiskPt::OnDiskWrapper &GetImplementation() const; @@ -82,6 +85,8 @@ public: const TargetPhraseCollection *GetTargetPhraseCollection(const OnDiskPt::PhraseNode *ptNode) const; const TargetPhraseCollection *GetTargetPhraseCollectionNonCache(const OnDiskPt::PhraseNode *ptNode) const; + void SetParameter(const std::string& key, const std::string& value); + }; } // namespace Moses diff --git a/moses/TranslationModel/UG/Jamfile b/moses/TranslationModel/UG/Jamfile index 1ee663044..ecd175a65 100644 --- a/moses/TranslationModel/UG/Jamfile +++ b/moses/TranslationModel/UG/Jamfile @@ -9,6 +9,17 @@ $(TOP)/moses/TranslationModel/UG//mmsapt $(TOP)/util//kenutil ; +exe ptable-lookup : +ptable-lookup.cc +$(TOP)/moses//moses +$(TOP)/moses/TranslationModel/UG/generic//generic +$(TOP)//boost_iostreams +$(TOP)//boost_program_options +$(TOP)/moses/TranslationModel/UG/mm//mm +$(TOP)/moses/TranslationModel/UG//mmsapt +$(TOP)/util//kenutil +; + install $(PREFIX)/bin : try-align ; fakelib mmsapt : [ glob *.cpp mmsapt*.cc ] ; diff --git a/moses/TranslationModel/UG/mm/custom-pt.cc b/moses/TranslationModel/UG/mm/custom-pt.cc index 9de67ff95..1c1e0893c 100644 --- a/moses/TranslationModel/UG/mm/custom-pt.cc +++ b/moses/TranslationModel/UG/mm/custom-pt.cc @@ -23,6 +23,7 @@ #include "ug_typedefs.h" #include "tpt_pickler.h" #include "ug_bitext.h" +#include "../mmsapt_phrase_scorers.h" #include "ug_lexical_phrase_scorer2.h" using namespace std; @@ -44,7 +45,7 @@ float lbsmooth = .005; PScorePfwd calc_pfwd; PScorePbwd calc_pbwd; -PScoreLex calc_lex; +PScoreLex calc_lex(1.0); PScoreWP apply_wp; vector fweights; @@ -129,8 
+130,8 @@ int main(int argc, char* argv[]) bt.setDefaultSampleSize(max_samples); size_t i; - i = calc_pfwd.init(0,.05); - i = calc_pbwd.init(i,.05); + i = calc_pfwd.init(0,.05,'g'); + i = calc_pbwd.init(i,.05,'g'); i = calc_lex.init(i,base+L1+"-"+L2+".lex"); i = apply_wp.init(i); diff --git a/moses/TranslationModel/UG/mm/mmlex-lookup.cc b/moses/TranslationModel/UG/mm/mmlex-lookup.cc index 14d839edf..fbdceeaa0 100644 --- a/moses/TranslationModel/UG/mm/mmlex-lookup.cc +++ b/moses/TranslationModel/UG/mm/mmlex-lookup.cc @@ -131,7 +131,7 @@ interpret_args(int ac, char* av[]) o.add_options() ("help,h", "print this message") ("source,s",po::value(&swrd),"source word") - ("target,t",po::value(&swrd),"target word") + ("target,t",po::value(&twrd),"target word") ; h.add_options() diff --git a/moses/TranslationModel/UG/mm/ug_bitext.cc b/moses/TranslationModel/UG/mm/ug_bitext.cc index c4f5175f3..8dbbdcb92 100644 --- a/moses/TranslationModel/UG/mm/ug_bitext.cc +++ b/moses/TranslationModel/UG/mm/ug_bitext.cc @@ -255,9 +255,10 @@ namespace Moses float lbop(size_t const tries, size_t const succ, float const confidence) { - return - boost::math::binomial_distribution<>:: - find_lower_bound_on_p(tries, succ, confidence); + return (confidence == 0 + ? 
float(succ)/tries + : (boost::math::binomial_distribution<>:: + find_lower_bound_on_p(tries, succ, confidence))); } PhrasePair const& diff --git a/moses/TranslationModel/UG/mm/ug_bitext.h b/moses/TranslationModel/UG/mm/ug_bitext.h index 84c3713ac..397253973 100644 --- a/moses/TranslationModel/UG/mm/ug_bitext.h +++ b/moses/TranslationModel/UG/mm/ug_bitext.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "moses/TranslationModel/UG/generic/sorting/VectorIndexSorter.h" #include "moses/TranslationModel/UG/generic/sampling/Sampling.h" @@ -193,239 +194,6 @@ namespace Moses { float eval(vector const& w); }; - template - class - PhraseScorer - { - protected: - int index; - int num_feats; - public: - - virtual - void - operator()(Bitext const& pt, PhrasePair& pp, vector * dest) - const = 0; - - int - fcnt() const - { return num_feats; } - - int - getIndex() const - { return index; } - }; - - template - class - PScorePfwd : public PhraseScorer - { - float conf; - char denom; - public: - PScorePfwd() - { - this->num_feats = 1; - } - - int - init(int const i, float const c, char d=0) - { - conf = c; - denom = d; - this->index = i; - return i + this->num_feats; - } - - void - operator()(Bitext const& bt, - PhrasePair & pp, - vector * dest = NULL) const - { - if (!dest) dest = &pp.fvals; - if (pp.joint > pp.good1) - { - cerr<index] = log(lbop(pp.good1, pp.joint, conf)); - break; - case 's': - (*dest)[this->index] = log(lbop(pp.sample1, pp.joint, conf)); - break; - case 'r': - (*dest)[this->index] = log(lbop(pp.raw1, pp.joint, conf)); - } - } - }; - - template - class - PScorePbwd : public PhraseScorer - { - float conf; - public: - PScorePbwd() - { - this->num_feats = 1; - } - - int - init(int const i, float const c) - { - conf = c; - this->index = i; - return i + this->num_feats; - } - - void - operator()(Bitext const& bt, PhrasePair& pp, - vector * dest = NULL) const - { - if (!dest) dest = &pp.fvals; - (*dest)[this->index] = 
log(lbop(max(pp.raw2,pp.joint),pp.joint,conf)); - } - }; - - template - class - PScoreLogCounts : public PhraseScorer - { - float conf; - public: - PScoreLogCounts() - { - this->num_feats = 4; - } - - int - init(int const i) - { - this->index = i; - return i + this->num_feats; - } - - void - operator()(Bitext const& bt, PhrasePair& pp, - vector * dest = NULL) const - { - if (!dest) dest = &pp.fvals; - size_t i = this->index; - assert(pp.raw1); - assert(pp.sample1); - assert(pp.joint); - assert(pp.raw2); - (*dest)[i] = log(pp.raw1); - (*dest)[++i] = log(pp.sample1); - (*dest)[++i] = log(pp.joint); - (*dest)[++i] = log(pp.raw2); - } - }; - - template - class - PScoreLex : public PhraseScorer - { - public: - LexicalPhraseScorer2 scorer; - - PScoreLex() { this->num_feats = 2; } - - int - init(int const i, string const& fname) - { - scorer.open(fname); - this->index = i; - return i + this->num_feats; - } - - void - operator()(Bitext const& bt, PhrasePair& pp, vector * dest = NULL) const - { - if (!dest) dest = &pp.fvals; - uint32_t sid1=0,sid2=0,off1=0,off2=0,len1=0,len2=0; - parse_pid(pp.p1, sid1, off1, len1); - parse_pid(pp.p2, sid2, off2, len2); - -#if 0 - cout << len1 << " " << len2 << endl; - Token const* t1 = bt.T1->sntStart(sid1); - for (size_t i = off1; i < off1 + len1; ++i) - cout << (*bt.V1)[t1[i].id()] << " "; - cout << __FILE__ << ":" << __LINE__ << endl; - - Token const* t2 = bt.T2->sntStart(sid2); - for (size_t i = off2; i < off2 + len2; ++i) - cout << (*bt.V2)[t2[i].id()] << " "; - cout << __FILE__ << ":" << __LINE__ << endl; - - BOOST_FOREACH (int a, pp.aln) - cout << a << " " ; - cout << __FILE__ << ":" << __LINE__ << "\n" << endl; - -#endif - scorer.score(bt.T1->sntStart(sid1)+off1,0,len1, - bt.T2->sntStart(sid2)+off2,0,len2, - pp.aln, (*dest)[this->index], - (*dest)[this->index+1]); - } - - }; - - /// Word penalty - template - class - PScoreWP : public PhraseScorer - { - public: - - PScoreWP() { this->num_feats = 1; } - - int - init(int const i) - { - 
this->index = i; - return i + this->num_feats; - } - - void - operator()(Bitext const& bt, PhrasePair& pp, vector * dest = NULL) const - { - if (!dest) dest = &pp.fvals; - uint32_t sid2=0,off2=0,len2=0; - parse_pid(pp.p2, sid2, off2, len2); - (*dest)[this->index] = len2; - } - - }; - - /// Phrase penalty - template - class - PScorePP : public PhraseScorer - { - public: - - PScorePP() { this->num_feats = 1; } - - int - init(int const i) - { - this->index = i; - return i + this->num_feats; - } - - void - operator()(Bitext const& bt, PhrasePair& pp, vector * dest = NULL) const - { - if (!dest) dest = &pp.fvals; - (*dest)[this->index] = 1; - } - - }; template class Bitext @@ -590,8 +358,9 @@ namespace Moses { static ThreadSafeCounter active; boost::mutex lock; friend class agenda; - boost::taus88 rnd; // every job has its own pseudo random generator - double rnddenom; // denominator for scaling random sampling + boost::taus88 rnd; // every job has its own pseudo random generator + double rnddenom; // denominator for scaling random sampling + size_t min_diverse; // minimum number of distinct translations public: size_t workers; // how many workers are working on this job? 
sptr const> root; // root of the underlying suffix array @@ -644,34 +413,47 @@ namespace Moses { step(uint64_t & sid, uint64_t & offset) { boost::lock_guard jguard(lock); - if ((max_samples == 0) && (next < stop)) + bool ret = (max_samples == 0) && (next < stop); + if (ret) { next = root->readSid(next,stop,sid); next = root->readOffset(next,stop,offset); boost::lock_guard sguard(stats->lock); if (stats->raw_cnt == ctr) ++stats->raw_cnt; stats->sample_cnt++; - return true; } else { - while (next < stop && stats->good < max_samples) + while (next < stop && (stats->good < max_samples || + stats->trg.size() < min_diverse)) { next = root->readSid(next,stop,sid); next = root->readOffset(next,stop,offset); - { - boost::lock_guard sguard(stats->lock); + { // brackets required for lock scoping; see sguard immediately below + boost::lock_guard sguard(stats->lock); if (stats->raw_cnt == ctr) ++stats->raw_cnt; - size_t rnum = (stats->raw_cnt - ctr++)*(rnd()/(rnd.max()+1.)); + size_t scalefac = (stats->raw_cnt - ctr++); + size_t rnum = scalefac*(rnd()/(rnd.max()+1.)); +#if 0 + cerr << rnum << "/" << scalefac << " vs. " + << max_samples - stats->good << " (" + << max_samples << " - " << stats->good << ")" + << endl; +#endif if (rnum < max_samples - stats->good) { stats->sample_cnt++; - return true; + ret = true; + break; } } } - return false; } + + // boost::lock_guard sguard(stats->lock); + // abuse of lock for clean output to cerr + // cerr << stats->sample_cnt++; + return ret; } template @@ -713,6 +495,13 @@ namespace Moses { worker:: operator()() { + // things to do: + // - have each worker maintain their own pstats object and merge results at the end; + // - ensure the minimum size of samples considered by a non-locked counter that is only + // ever incremented -- who cares if we look at more samples than required, as long + // as we look at at least the minimum required + // This way, we can reduce the number of lock / unlock operations we need to do during + // sampling. 
size_t s1=0, s2=0, e1=0, e2=0; uint64_t sid=0, offset=0; // of the source phrase while(sptr j = ag.get_job()) @@ -812,6 +601,7 @@ namespace Moses { sptr > const& r, size_t maxsmpl, bool isfwd) : rnd(0) , rnddenom(rnd.max() + 1.) + , min_diverse(10) , workers(0) , root(r) , next(m.lower_bound(-1)) diff --git a/moses/TranslationModel/UG/mm/ug_lexical_phrase_scorer2.h b/moses/TranslationModel/UG/mm/ug_lexical_phrase_scorer2.h index 2d64705f7..558b5a7fa 100644 --- a/moses/TranslationModel/UG/mm/ug_lexical_phrase_scorer2.h +++ b/moses/TranslationModel/UG/mm/ug_lexical_phrase_scorer2.h @@ -2,6 +2,8 @@ // lexical phrase scorer, version 1 // written by Ulrich Germann +// Is the +1 in computing the lexical probabilities taken from the original phrase-scoring code? + #ifndef __ug_lexical_phrase_scorer_h #define __ug_lexical_phrase_scorer_h @@ -11,6 +13,7 @@ #include #include "tpt_pickler.h" #include "ug_mm_2d_table.h" +#include "util/exception.hh" using namespace std; namespace ugdiss { @@ -19,6 +22,7 @@ namespace ugdiss class LexicalPhraseScorer2 { + vector ftag; public: typedef mm2dTable table_t; table_t COOC; @@ -28,16 +32,18 @@ namespace ugdiss void score(TKN const* snt1, size_t const s1, size_t const e1, TKN const* snt2, size_t const s2, size_t const e2, - vector & aln, float & fwd_score, float& bwd_score) const; + vector const & aln, float const alpha, + float & fwd_score, float& bwd_score) const; void score(TKN const* snt1, size_t const s1, size_t const e1, TKN const* snt2, size_t const s2, size_t const e2, char const* const aln_start, char const* const aln_end, - float & fwd_score, float& bwd_score) const; + float const alpha, float & fwd_score, float& bwd_score) const; + // plup: permissive lookup - float plup_fwd(id_type const s,id_type const t) const; - float plup_bwd(id_type const s,id_type const t) const; + float plup_fwd(id_type const s,id_type const t, float const alpha) const; + float plup_bwd(id_type const s,id_type const t, float const alpha) const; // to 
be done: // - on-the-fly smoothing ? // - better (than permissive-lookup) treatment of unknown combinations @@ -59,7 +65,8 @@ namespace ugdiss LexicalPhraseScorer2:: score(TKN const* snt1, size_t const s1, size_t const e1, TKN const* snt2, size_t const s2, size_t const e2, - vector & aln, float & fwd_score, float& bwd_score) const + vector const & aln, float const alpha, + float & fwd_score, float& bwd_score) const { vector p1(e1,0), p2(e2,0); vector c1(e1,0), c2(e2,0); @@ -68,9 +75,9 @@ namespace ugdiss { i1 = aln[k]; i2 = aln[++k]; if (i1 < s1 || i1 >= e1 || i2 < s2 || i2 >= e2) continue; - p1[i1] += plup_fwd(snt1[i1].id(),snt2[i2].id()); + p1[i1] += plup_fwd(snt1[i1].id(),snt2[i2].id(),alpha); ++c1[i1]; - p2[i2] += plup_bwd(snt1[i1].id(),snt2[i2].id()); + p2[i2] += plup_bwd(snt1[i1].id(),snt2[i2].id(),alpha); ++c2[i2]; } fwd_score = 0; @@ -78,45 +85,46 @@ namespace ugdiss { if (c1[i] == 1) fwd_score += log(p1[i]); else if (c1[i]) fwd_score += log(p1[i])-log(c1[i]); - else fwd_score += log(plup_fwd(snt1[i].id(),0)); + else fwd_score += log(plup_fwd(snt1[i].id(),0,alpha)); } bwd_score = 0; for (size_t i = s2; i < e2; ++i) { if (c2[i] == 1) bwd_score += log(p2[i]); else if (c2[i]) bwd_score += log(p2[i])-log(c2[i]); - else bwd_score += log(plup_bwd(0,snt2[i].id())); + else bwd_score += log(plup_bwd(0,snt2[i].id(),alpha)); } } template float LexicalPhraseScorer2:: - plup_fwd(id_type const s, id_type const t) const + plup_fwd(id_type const s, id_type const t, float const alpha) const { if (COOC.m1(s) == 0 || COOC.m2(t) == 0) return 1.0; - // if (!COOC[s][t]) cout << s << " " << t << endl; - // assert(COOC[s][t]); - return float(COOC[s][t]+1)/(COOC.m1(s)+1); + UTIL_THROW_IF2(alpha < 0,"At " << __FILE__ << ":" << __LINE__ + << ": alpha parameter must be >= 0"); + return float(COOC[s][t]+alpha)/(COOC.m1(s)+alpha); } - + template float LexicalPhraseScorer2:: - plup_bwd(id_type const s, id_type const t) const + plup_bwd(id_type const s, id_type const t,float const alpha) 
const { if (COOC.m1(s) == 0 || COOC.m2(t) == 0) return 1.0; - // assert(COOC[s][t]); - return float(COOC[s][t]+1)/(COOC.m2(t)+1); + UTIL_THROW_IF2(alpha < 0,"At " << __FILE__ << ":" << __LINE__ + << ": alpha parameter must be >= 0"); + return float(COOC[s][t]+alpha)/(COOC.m2(t)+alpha); } - + template void LexicalPhraseScorer2:: score(TKN const* snt1, size_t const s1, size_t const e1, TKN const* snt2, size_t const s2, size_t const e2, char const* const aln_start, char const* const aln_end, - float & fwd_score, float& bwd_score) const + float const alpha, float & fwd_score, float& bwd_score) const { vector p1(e1,0), p2(e2,0); vector c1(e1,0), c2(e2,0); @@ -125,9 +133,9 @@ namespace ugdiss { x = binread(binread(x,i1),i2); if (i1 < s1 || i1 >= e1 || i2 < s2 || i2 >= e2) continue; - p1[i1] += plup_fwd(snt1[i1].id(), snt2[i2].id()); + p1[i1] += plup_fwd(snt1[i1].id(), snt2[i2].id(),alpha); ++c1[i1]; - p2[i2] += plup_bwd(snt1[i1].id(), snt2[i2].id()); + p2[i2] += plup_bwd(snt1[i1].id(), snt2[i2].id(),alpha); ++c2[i2]; } fwd_score = 0; @@ -135,14 +143,14 @@ namespace ugdiss { if (c1[i] == 1) fwd_score += log(p1[i]); else if (c1[i]) fwd_score += log(p1[i])-log(c1[i]); - else fwd_score += log(plup_fwd(snt1[i].id(),0)); + else fwd_score += log(plup_fwd(snt1[i].id(),0,alpha)); } bwd_score = 0; for (size_t i = s2; i < e2; ++i) { if (c2[i] == 1) bwd_score += log(p2[i]); else if (c2[i]) bwd_score += log(p2[i])-log(c2[i]); - else bwd_score += log(plup_bwd(0,snt2[i].id())); + else bwd_score += log(plup_bwd(0,snt2[i].id(),alpha)); } } } diff --git a/moses/TranslationModel/UG/mmsapt.cpp b/moses/TranslationModel/UG/mmsapt.cpp index 128dcfe80..dc9945472 100644 --- a/moses/TranslationModel/UG/mmsapt.cpp +++ b/moses/TranslationModel/UG/mmsapt.cpp @@ -47,15 +47,25 @@ namespace Moses } #endif + vector const& + Mmsapt:: + GetFeatureNames() const + { + return m_feature_names; + } + Mmsapt:: Mmsapt(string const& line) - // : PhraseDictionary("Mmsapt",line), ofactor(1,0) : 
PhraseDictionary(line) + , m_lex_alpha(1.0) , withLogCountFeatures(false) - , withPfwd(true), withPbwd(true) + , withCoherence(true) + , m_pfwd_features("g") + , m_pbwd_features("g") + , withPbwd(true) + , poolCounts(true) , ofactor(1,0) , m_tpc_ctr(0) - // default values chosen for bwd probability { this->init(line); } @@ -101,52 +111,59 @@ namespace Moses assert(L1.size()); assert(L2.size()); - m = param.find("pfwd_denom"); + m = param.find("pfwd-denom"); m_pfwd_denom = m != param.end() ? m->second[0] : 's'; - + m = param.find("smooth"); m_lbop_parameter = m != param.end() ? atof(m->second.c_str()) : .05; m = param.find("max-samples"); m_default_sample_size = m != param.end() ? atoi(m->second.c_str()) : 1000; - m = param.find("logcnt-features"); - if (m != param.end()) + if ((m = param.find("logcnt-features")) != param.end()) withLogCountFeatures = m->second != "0"; - m = param.find("pfwd"); - if (m != param.end()) - withPfwd = m->second != "0"; + if ((m = param.find("coh")) != param.end()) + withCoherence = m->second != "0"; + + if ((m = param.find("pfwd")) != param.end()) + m_pfwd_features = (m->second == "0" ? "" : m->second); - m = param.find("pbwd"); - if (m != param.end()) - withPbwd = m->second != "0"; - - m_default_sample_size = m != param.end() ? atoi(m->second.c_str()) : 1000; + if (m_pfwd_features == "1") // legacy; deprecated + m_pfwd_features[0] = m_pfwd_denom; + + if ((m = param.find("pbwd")) != param.end()) + m_pbwd_features = (m->second == "0" ? "" : m->second); + + if (m_pbwd_features == "1") + m_pbwd_features = "r"; // legacy; deprecated + + if ((m = param.find("lexalpha")) != param.end()) + m_lex_alpha = atof(m->second.c_str()); m = param.find("workers"); m_workers = m != param.end() ? atoi(m->second.c_str()) : 8; m_workers = min(m_workers,24UL); + if ((m = param.find("limit")) != param.end()) + m_tableLimit = atoi(m->second.c_str()); + m = param.find("cache-size"); - m_history.reserve(m != param.end() - ? 
max(1000,atoi(m->second.c_str())) - : 10000); + m_history.reserve(m != param.end()?max(1000,atoi(m->second.c_str())):10000); + // in plain language: cache size is at least 1000, and 10,000 by default + // this cache keeps track of the most frequently used target phrase collections + // even when not actively in use this->m_numScoreComponents = atoi(param["num-features"].c_str()); - // num_features = 0; m = param.find("ifactor"); input_factor = m != param.end() ? atoi(m->second.c_str()) : 0; + poolCounts = true; - m = param.find("extra"); - if (m != param.end()) - { - extra_data = m->second; - // cerr << "have extra data" << endl; - } - // keeps track of the most frequently used target phrase collections - // (to keep them cached even when not actively in use) + + if ((m = param.find("extra")) != param.end()) + extra_data = m->second; + } void @@ -175,6 +192,63 @@ namespace Moses // cerr << "Loaded " << btdyn->T1->size() << " sentence pairs" << endl; } + size_t + Mmsapt:: + add_corpus_specific_features + (vector >& ffvec, size_t num_feats) + { + float const lbop = m_lbop_parameter; // just for code readability below + // for the time being, we assume that all phrase probability features + // use the same confidence parameter for lower-bound-estimation + for (size_t i = 0; i < m_pfwd_features.size(); ++i) + { + UTIL_THROW_IF2(m_pfwd_features[i] != 'g' && + m_pfwd_features[i] != 'r' && + m_pfwd_features[i] != 's', + "Can't handle pfwd feature type '" + << m_pfwd_features[i] << "'."); + sptr > ff(new PScorePfwd()); + size_t k = num_feats; + num_feats = ff->init(num_feats,lbop,m_pfwd_features[i]); + for (;k < num_feats; ++k) m_feature_names.push_back(ff->fname(k)); + ffvec.push_back(ff); + } + + for (size_t i = 0; i < m_pbwd_features.size(); ++i) + { + UTIL_THROW_IF2(m_pbwd_features[i] != 'g' && + m_pbwd_features[i] != 'r' && + m_pbwd_features[i] != 's', + "Can't handle pbwd feature type '" + << m_pbwd_features[i] << "'."); + sptr > ff(new PScorePbwd()); + size_t k = 
num_feats; + num_feats = ff->init(num_feats,lbop,m_pbwd_features[i]); + for (;k < num_feats; ++k) m_feature_names.push_back(ff->fname(k)); + ffvec.push_back(ff); + } + + // if (withPbwd) + // { + // sptr > ff(new PScorePbwd()); + // size_t k = num_feats; + // num_feats = ff->init(num_feats,lbop); + // for (; k < num_feats; ++k) m_feature_names.push_back(ff->fname(k)); + // ffvec.push_back(ff); + // } + + if (withLogCountFeatures) + { + sptr > ff(new PScoreLogCounts()); + size_t k = num_feats; + num_feats = ff->init(num_feats); + for (; k < num_feats; ++k) m_feature_names.push_back(ff->fname(k)); + ffvec.push_back(ff); + } + + return num_feats; + } + void Mmsapt:: Load() @@ -184,44 +258,52 @@ namespace Moses btfix.setDefaultSampleSize(m_default_sample_size); size_t num_feats = 0; - // TO DO: should we use different lbop parameters - // for the relative-frequency based features? - if (withLogCountFeatures) num_feats = add_logcounts_fix.init(num_feats); - - float const lbop = m_lbop_parameter; // just for code readability below - if (withPfwd) num_feats = calc_pfwd_fix.init(num_feats,lbop,m_pfwd_denom); - if (withPbwd) num_feats = calc_pbwd_fix.init(num_feats,lbop); + // lexical scores are currently always active + sptr > ff(new PScoreLex(m_lex_alpha)); + size_t k = num_feats; + num_feats = ff->init(num_feats, bname + L1 + "-" + L2 + ".lex"); + for (; k < num_feats; ++k) m_feature_names.push_back(ff->fname(k)); + m_active_ff_common.push_back(ff); - // currently always active by default; may (should) change later - num_feats = calc_lex.init(num_feats, bname + L1 + "-" + L2 + ".lex"); - - if (this->m_numScoreComponents%2) // a bit of a hack, for backwards compatibility - num_feats = apply_pp.init(num_feats); - - if (num_feats < this->m_numScoreComponents) + if (withCoherence) { - poolCounts = false; - if (withLogCountFeatures) num_feats = add_logcounts_dyn.init(num_feats); - if (withPfwd) num_feats = calc_pfwd_dyn.init(num_feats,lbop,m_pfwd_denom); - if (withPbwd) 
num_feats = calc_pbwd_dyn.init(num_feats,lbop); + sptr > ff(new PScoreCoherence()); + size_t k = num_feats; + num_feats = ff->init(num_feats); + for (; k < num_feats; ++k) m_feature_names.push_back(ff->fname(k)); + m_active_ff_common.push_back(ff); } - - if (num_feats != this->m_numScoreComponents) - { - ostringstream buf; - buf << "At " << __FILE__ << ":" << __LINE__ - << ": number of feature values provided by Phrase table" - << " does not match number specified in Moses config file!"; - throw buf.str().c_str(); - } - // cerr << "MMSAPT provides " << num_feats << " features at " - // << __FILE__ << ":" << __LINE__ << endl; + num_feats = add_corpus_specific_features(m_active_ff_fix,num_feats); + // cerr << num_feats << "/" << this->m_numScoreComponents + // << " at " << __FILE__ << ":" << __LINE__ << endl; + poolCounts = poolCounts && num_feats == this->m_numScoreComponents; + if (!poolCounts) + num_feats = add_corpus_specific_features(m_active_ff_dyn, num_feats); + +#if 0 + cerr << "MMSAPT provides " << num_feats << " features at " + << __FILE__ << ":" << __LINE__ << endl; + BOOST_FOREACH(string const& fname, m_feature_names) + cerr << fname << endl; +#endif + UTIL_THROW_IF2(num_feats != this->m_numScoreComponents, + "At " << __FILE__ << ":" << __LINE__ + << ": number of feature values provided by Phrase table (" + << num_feats << ") does not match number specified in " + << "Moses config file (" << this->m_numScoreComponents + << ")!\n";); + + btdyn.reset(new imBitext(btfix.V1, btfix.V2,m_default_sample_size)); btdyn->num_workers = this->m_workers; - if (extra_data.size()) load_extra_data(extra_data); - + if (extra_data.size()) + { + load_extra_data(extra_data); + } + +#if 0 // currently not used LexicalPhraseScorer2::table_t & COOC = calc_lex.scorer.COOC; typedef LexicalPhraseScorer2::table_t::Cell cell_t; @@ -230,7 +312,8 @@ namespace Moses for (cell_t const* c = COOC[r].start; c < COOC[r].stop; ++c) wlex21[c->id].push_back(r); COOCraw.open(bname + L1 + "-" + 
L2 + ".coc"); - +#endif + } void @@ -283,20 +366,28 @@ namespace Moses { PhrasePair pp; pp.init(pid1, stats, this->m_numScoreComponents); - if (this->m_numScoreComponents%2) - apply_pp(bt,pp); pstats::trg_map_t::const_iterator t; for (t = stats.trg.begin(); t != stats.trg.end(); ++t) { pp.update(t->first,t->second); - calc_lex(bt,pp); - if (withPfwd) calc_pfwd_fix(bt,pp); - if (withPbwd) calc_pbwd_fix(bt,pp); - if (withLogCountFeatures) add_logcounts_fix(bt,pp); + BOOST_FOREACH(sptr const& ff, m_active_ff_fix) + (*ff)(bt,pp); + BOOST_FOREACH(sptr const& ff, m_active_ff_common) + (*ff)(bt,pp); tpcoll->Add(createTargetPhrase(src,bt,pp)); } } + void + Mmsapt:: + ScorePPfix(bitext::PhrasePair& pp) const + { + BOOST_FOREACH(sptr const& ff, m_active_ff_fix) + (*ff)(btfix,pp); + BOOST_FOREACH(sptr const& ff, m_active_ff_common) + (*ff)(btfix,pp); + } + // process phrase stats from a single parallel corpus bool Mmsapt:: @@ -318,8 +409,6 @@ namespace Moses pp.init(pid1b, *statsb, this->m_numScoreComponents); else return false; // throw "no stats for pooling available!"; - if (this->m_numScoreComponents%2) - apply_pp(bta,pp); pstats::trg_map_t::const_iterator b; pstats::trg_map_t::iterator a; if (statsb) @@ -344,10 +433,10 @@ namespace Moses b->second); } else pp.update(b->first,b->second); - calc_lex(btb,pp); - if (withPfwd) calc_pfwd_fix(btb,pp); - if (withPbwd) calc_pbwd_fix(btb,pp); - if (withLogCountFeatures) add_logcounts_fix(btb,pp); + BOOST_FOREACH(sptr const& ff, m_active_ff_fix) + (*ff)(btb,pp); + BOOST_FOREACH(sptr const& ff, m_active_ff_common) + (*ff)(btb,pp); tpcoll->Add(createTargetPhrase(src,btb,pp)); } } @@ -368,28 +457,28 @@ namespace Moses } else pp.update(a->first,a->second); +#if 0 + // jstats const& j = a->second; + cerr << bta.T1->pid2str(bta.V1.get(),pp.p1) << " ::: " + << bta.T2->pid2str(bta.V2.get(),pp.p2) << endl; + cerr << pp.raw1 << " " << pp.sample1 << " " << pp.good1 << " " + << pp.joint << " " << pp.raw2 << endl; +#endif UTIL_THROW_IF2(pp.raw2 
== 0, - "OOPS" - << bta.T1->pid2str(bta.V1.get(),pp.p1) << " ::: " + "OOPS" << bta.T1->pid2str(bta.V1.get(),pp.p1) << " ::: " << bta.T2->pid2str(bta.V2.get(),pp.p2) << ": " << pp.raw1 << " " << pp.sample1 << " " << pp.good1 << " " << pp.joint << " " << pp.raw2); -#if 0 - jstats const& j = a->second; - cerr << bta.T1->pid2str(bta.V1.get(),pp.p1) << " ::: " - << bta.T2->pid2str(bta.V2.get(),pp.p2) << endl; - cerr << j.rcnt() << " " << j.cnt2() << " " << j.wcnt() << endl; -#endif - calc_lex(bta,pp); - if (withPfwd) calc_pfwd_fix(bta,pp); - if (withPbwd) calc_pbwd_fix(bta,pp); - if (withLogCountFeatures) add_logcounts_fix(bta,pp); + BOOST_FOREACH(sptr const& ff, m_active_ff_fix) + (*ff)(bta,pp); + BOOST_FOREACH(sptr const& ff, m_active_ff_common) + (*ff)(bta,pp); tpcoll->Add(createTargetPhrase(src,bta,pp)); } return true; -} + } // process phrase stats from a single parallel corpus @@ -397,75 +486,81 @@ namespace Moses Mmsapt:: combine_pstats (Phrase const& src, - uint64_t const pid1a, - pstats * statsa, - Bitext const & bta, - uint64_t const pid1b, - pstats const* statsb, - Bitext const & btb, - TargetPhraseCollection* tpcoll - ) const + uint64_t const pid1a, pstats * statsa, Bitext const & bta, + uint64_t const pid1b, pstats const* statsb, Bitext const & btb, + TargetPhraseCollection* tpcoll) const { PhrasePair ppfix,ppdyn,pool; + // ppfix: counts from btfix + // ppdyn: counts from btdyn + // pool: pooled counts from both Word w; if (statsa) ppfix.init(pid1a,*statsa,this->m_numScoreComponents); if (statsb) ppdyn.init(pid1b,*statsb,this->m_numScoreComponents); pstats::trg_map_t::const_iterator b; pstats::trg_map_t::iterator a; + if (statsb) { pool.init(pid1b,*statsb,0); - if (this->m_numScoreComponents%2) - apply_pp(btb,ppdyn); for (b = statsb->trg.begin(); b != statsb->trg.end(); ++b) { ppdyn.update(b->first,b->second); - if (withPfwd) calc_pfwd_dyn(btb,ppdyn); - if (withPbwd) calc_pbwd_dyn(btb,ppdyn); - if (withLogCountFeatures) add_logcounts_dyn(btb,ppdyn); - 
calc_lex(btb,ppdyn); + BOOST_FOREACH(sptr const& ff, m_active_ff_dyn) + (*ff)(btb,ppdyn); uint32_t sid,off,len; parse_pid(b->first, sid, off, len); Token const* x = bta.T2->sntStart(sid) + off; TSA::tree_iterator m(bta.I2.get(),x,x+len); + if (m.size() && statsa && - ((a = statsa->trg.find(m.getPid())) - != statsa->trg.end())) + ((a = statsa->trg.find(m.getPid())) != statsa->trg.end())) { + // phrase pair found also in btfix ppfix.update(a->first,a->second); - if (withPfwd) calc_pfwd_fix(bta,ppfix,&ppdyn.fvals); - if (withPbwd) calc_pbwd_fix(bta,ppfix,&ppdyn.fvals); - if (withLogCountFeatures) add_logcounts_fix(bta,ppfix,&ppdyn.fvals); + BOOST_FOREACH(sptr const& ff, m_active_ff_fix) + (*ff)(bta,ppfix,&ppdyn.fvals); + BOOST_FOREACH(sptr const& ff, m_active_ff_common) + (*ff)(bta,ppfix,&ppdyn.fvals); a->second.invalidate(); } else { - if (m.size()) - pool.update(b->first,m.approxOccurrenceCount(), - b->second); - else + // phrase pair was not found in btfix + + // ... but the source phrase was + if (m.size()) + pool.update(b->first,m.approxOccurrenceCount(), b->second); + + // ... 
and not even the source phrase + else pool.update(b->first,b->second); - if (withPfwd) calc_pfwd_fix(btb,pool,&ppdyn.fvals); - if (withPbwd) calc_pbwd_fix(btb,pool,&ppdyn.fvals); - if (withLogCountFeatures) add_logcounts_fix(btb,pool,&ppdyn.fvals); + + BOOST_FOREACH(sptr const& ff, m_active_ff_fix) + (*ff)(btb,pool,&ppdyn.fvals); + BOOST_FOREACH(sptr const& ff, m_active_ff_common) + (*ff)(btb,pool,&ppdyn.fvals); + } + tpcoll->Add(createTargetPhrase(src,btb,ppdyn)); } } + + // now deal with all phrase pairs that are ONLY in btfix + // (the ones that are in both were dealt with above) if (statsa) { pool.init(pid1a,*statsa,0); - if (this->m_numScoreComponents%2) - apply_pp(bta,ppfix); for (a = statsa->trg.begin(); a != statsa->trg.end(); ++a) { if (!a->second.valid()) continue; // done above ppfix.update(a->first,a->second); - if (withPfwd) calc_pfwd_fix(bta,ppfix); - if (withPbwd) calc_pbwd_fix(bta,ppfix); - if (withLogCountFeatures) add_logcounts_fix(bta,ppfix); - calc_lex(bta,ppfix); + BOOST_FOREACH(sptr const& ff, m_active_ff_fix) + (*ff)(bta,ppfix); + BOOST_FOREACH(sptr const& ff, m_active_ff_common) + (*ff)(bta,ppfix); if (btb.I2) { @@ -479,102 +574,15 @@ namespace Moses pool.update(a->first,a->second); } else pool.update(a->first,a->second); - if (withPfwd) calc_pfwd_dyn(bta,pool,&ppfix.fvals); - if (withPbwd) calc_pbwd_dyn(bta,pool,&ppfix.fvals); - if (withLogCountFeatures) add_logcounts_dyn(bta,pool,&ppfix.fvals); + BOOST_FOREACH(sptr const& ff, m_active_ff_dyn) + (*ff)(btb,pool,&ppfix.fvals); + if (ppfix.p2) + tpcoll->Add(createTargetPhrase(src,bta,ppfix)); + } - if (ppfix.p2) - tpcoll->Add(createTargetPhrase(src,bta,ppfix)); } return (statsa || statsb); } - // // phrase statistics combination treating the two knowledge - // // sources separately with backoff to pooling when only one - // // of the two knowledge sources contains the phrase pair in - // // question - // void - // Mmsapt:: - // process_pstats(uint64_t const mypid1, - // uint64_t const otpid1, 
// pstats const& mystats, // my phrase stats - // pstats const* otstats, // other phrase stats - // Bitext const & mybt, // my bitext - // Bitext const * otbt, // other bitext - // PhraseScorer const& mypfwd, - // PhraseScorer const& mypbwd, - // PhraseScorer const* otpfwd, - // PhraseScorer const* otpbwd, - // TargetPhraseCollection* tpcoll) - // { - // boost::unordered_map::const_iterator t; - // vector ofact(1,0); - // PhrasePair mypp,otpp,combo; - // mypp.init(mypid1, mystats, this->m_numScoreComponents); - // if (otstats) - // { - // otpp.init(otpid1, *otstats, 0); - // combo.init(otpid1, mystats, *otstats, 0); - // } - // else combo = mypp; - - // for (t = mystats.trg.begin(); t != mystats.trg.end(); ++t) - // { - // if (!t->second.valid()) continue; - // // we dealt with this phrase pair already; - // // see j->second.invalidate() below; - // uint32_t sid,off,len; parse_pid(t->first,sid,off,len); - - // mypp.update(t->first,t->second); - // apply_pp(mybt,mypp); - // calc_lex (mybt,mypp); - // mypfwd(mybt,mypp); - // mypbwd(mybt,mypp); - - // if (otbt) // it's a dynamic phrase table - // { - // assert(otpfwd); - // assert(otpbwd); - // boost::unordered_map::iterator j; - - // // look up the current target phrase in the other bitext - // Token const* x = mybt.T2->sntStart(sid) + off; - // TSA::tree_iterator m(otbt->I2.get(),x,x+len); - // if (otstats // source phrase exists in other bitext - // && m.size() // target phrase exists in other bitext - // && ((j = otstats->trg.find(m.getPid())) - // != otstats->trg.end())) // phrase pair found in other bitext - // { - // otpp.update(j->first,j->second); - // j->second.invalidate(); // mark the phrase pair as seen - // otpfwd(*otbt,otpp,&mypp.fvals); - // otpbwd(*otbt,otpp,&mypp.fvals); - // } - // else - // { - // if (m.size()) // target phrase seen in other bitext, but not the phrase pair - // combo.update(t->first,m.approxOccurrenceCount(),t->second); - // else - // combo.update(t->first,t->second); - // 
(*otpfwd)(mybt,combo,&mypp.fvals); - // (*otpbwd)(mybt,combo,&mypp.fvals); - // } - // } - - // // now add the phrase pair to the TargetPhraseCollection: - // TargetPhrase* tp = new TargetPhrase(); - // for (size_t k = off; k < stop; ++k) - // { - // StringPiece wrd = (*mybt.V2)[x[k].id()]; - // Word w; w.CreateFromString(Output,ofact,wrd,false); - // tp->AddWord(w); - // } - // tp->GetScoreBreakdown().Assign(this,mypp.fvals); - // tp->Evaluate(src); - // tpcoll->Add(tp); - // } - // } - Mmsapt:: TargetPhraseCollectionWrapper:: TargetPhraseCollectionWrapper(size_t r, uint64_t k) @@ -662,7 +670,7 @@ namespace Moses || combine_pstats(src, mfix.getPid(),sfix.get(),btfix, mdyn.getPid(),sdyn.get(),*dyn,ret)) { - ret->NthElement(m_tableLimit); + if (m_tableLimit) ret->Prune(true,m_tableLimit); #if 0 sort(ret->begin(), ret->end(), CompareTargetPhrase()); cout << "SOURCE PHRASE: " << src << endl; @@ -683,6 +691,14 @@ namespace Moses return encache(ret); } + size_t + Mmsapt:: + SetTableLimit(size_t limit) + { + std::swap(m_tableLimit,limit); + return limit; + } + void Mmsapt:: CleanUpAfterSentenceProcessing(const InputType& source) diff --git a/moses/TranslationModel/UG/mmsapt.h b/moses/TranslationModel/UG/mmsapt.h index 5353a1c46..b6be36131 100644 --- a/moses/TranslationModel/UG/mmsapt.h +++ b/moses/TranslationModel/UG/mmsapt.h @@ -29,6 +29,7 @@ #include #include "moses/TranslationModel/PhraseDictionary.h" +#include "mmsapt_phrase_scorers.h" // TO DO: // - make lexical phrase scorer take addition to the "dynamic overlay" into account @@ -51,6 +52,7 @@ namespace Moses typedef mmBitext mmbitext; typedef imBitext imbitext; typedef TSA tsa; + typedef PhraseScorer pscorer; private: mmbitext btfix; sptr btdyn; @@ -58,30 +60,49 @@ namespace Moses string L1; string L2; float m_lbop_parameter; + float m_lex_alpha; + // alpha parameter for lexical smoothing (joint+alpha)/(marg + alpha) + // must be > 0 if dynamic size_t m_default_sample_size; size_t m_workers; // number of worker 
threads for sampling the bitexts + + // deprecated! char m_pfwd_denom; // denominator for computation of fwd phrase score: // 'r' - divide by raw count // 's' - divide by sample count // 'g' - devide by number of "good" (i.e. coherent) samples // size_t num_features; + size_t input_factor; size_t output_factor; // we can actually return entire Tokens! + + bool withLogCountFeatures; // add logs of counts as features? + bool withCoherence; + string m_pfwd_features; // which pfwd functions to use + string m_pbwd_features; // which pbwd functions to use + vector m_feature_names; // names of features activated + vector > m_active_ff_fix; // activated feature functions (fix) + vector > m_active_ff_dyn; // activated feature functions (dyn) + vector > m_active_ff_common; // activated feature functions (dyn) + + size_t + add_corpus_specific_features + (vector >& ffvec, size_t num_feats); + // built-in feature functions - PScorePfwd calc_pfwd_fix, calc_pfwd_dyn; - PScorePbwd calc_pbwd_fix, calc_pbwd_dyn; - PScoreLex calc_lex; // this one I'd like to see as an external ff eventually - PScorePP apply_pp; // apply phrase penalty - PScoreLogCounts add_logcounts_fix; - PScoreLogCounts add_logcounts_dyn; + // PScorePfwd calc_pfwd_fix, calc_pfwd_dyn; + // PScorePbwd calc_pbwd_fix, calc_pbwd_dyn; + // PScoreLex calc_lex; // this one I'd like to see as an external ff eventually + // PScorePP apply_pp; // apply phrase penalty + // PScoreLogCounts add_logcounts_fix; + // PScoreLogCounts add_logcounts_dyn; void init(string const& line); mutable boost::mutex lock; + bool withPbwd; bool poolCounts; - bool withLogCountFeatures; // add logs of counts as features? 
- bool withPfwd,withPbwd; vector ofactor; - + public: // typedef boost::unordered_map > tpcoll_cache_t; class TargetPhraseCollectionWrapper @@ -168,6 +189,9 @@ namespace Moses void Load(); + // returns the prior table limit + size_t SetTableLimit(size_t limit); + #ifndef NO_MOSES TargetPhraseCollection const* GetTargetPhraseCollectionLEGACY(const Phrase& src) const; @@ -204,6 +228,12 @@ namespace Moses bool PrefixExists(Phrase const& phrase) const; + vector const& + GetFeatureNames() const; + + void + ScorePPfix(bitext::PhrasePair& pp) const; + private: }; } // end namespace diff --git a/moses/TranslationModel/UG/mmsapt_align.cc b/moses/TranslationModel/UG/mmsapt_align.cc index 4dd6081b0..407df648d 100644 --- a/moses/TranslationModel/UG/mmsapt_align.cc +++ b/moses/TranslationModel/UG/mmsapt_align.cc @@ -127,6 +127,7 @@ namespace Moses Alignment:: show(ostream& out, PhraseAlnHyp const& ah) { +#if 0 LexicalPhraseScorer2::table_t const& COOCjnt = PT.calc_lex.scorer.COOC; @@ -164,6 +165,7 @@ namespace Moses // << " jbwd: " << obwdj[po_jbwd]<<"/"<first); if (R == tpid2span.end()) continue; pp.update(y->first, y->second); - PT.calc_lex(PT.btfix,pp); - PT.calc_pfwd_fix(PT.btfix,pp); - PT.calc_pbwd_fix(PT.btfix,pp); + PT.ScorePPfix(pp); pp.eval(PT.feature_weights); PP.push_back(pp); BOOST_FOREACH(span const& sspan, L->second) @@ -329,6 +329,7 @@ namespace Moses BOOST_FOREACH(int i, o) A.show(cout,A.PAH[i]); sptr > aln; return aln; - } +} } + diff --git a/moses/TranslationModel/UG/mmsapt_phrase_scorers.h b/moses/TranslationModel/UG/mmsapt_phrase_scorers.h new file mode 100644 index 000000000..6e852b44b --- /dev/null +++ b/moses/TranslationModel/UG/mmsapt_phrase_scorers.h @@ -0,0 +1,318 @@ +// -*- c++ -*- +#pragma once +#include "moses/TranslationModel/UG/mm/ug_bitext.h" +#include "util/exception.hh" + +namespace Moses { + namespace bitext + { + + template + class + PhraseScorer + { + protected: + int m_index; + int m_num_feats; + vector m_feature_names; + public: + + 
virtual + void + operator()(Bitext const& pt, PhrasePair& pp, vector * dest=NULL) + const = 0; + + int + fcnt() const + { return m_num_feats; } + + vector const & + fnames() const + { return m_feature_names; } + + string const & + fname(int i) const + { + UTIL_THROW_IF2((i < m_index || i >= m_index + m_num_feats), + "Feature name index out of range at " + << __FILE__ << ":" << __LINE__); + return m_feature_names.at(i - m_index); + } + + int + getIndex() const + { return m_index; } + }; + + //////////////////////////////////////////////////////////////////////////////// + + template + class + PScorePfwd : public PhraseScorer + { + float conf; + char denom; + public: + PScorePfwd() + { + this->m_num_feats = 1; + } + + int + init(int const i, float const c, char d) + { + conf = c; + denom = d; + this->m_index = i; + ostringstream buf; + buf << format("pfwd-%c%.3f") % denom % c; + this->m_feature_names.push_back(buf.str()); + return i + this->m_num_feats; + } + + void + operator()(Bitext const& bt, PhrasePair & pp, + vector * dest = NULL) const + { + if (!dest) dest = &pp.fvals; + if (pp.joint > pp.good1) + { + cerr<m_index] = log(lbop(pp.good1, pp.joint, conf)); + break; + case 's': + (*dest)[this->m_index] = log(lbop(pp.sample1, pp.joint, conf)); + break; + case 'r': + (*dest)[this->m_index] = log(lbop(pp.raw1, pp.joint, conf)); + } + } + }; + + //////////////////////////////////////////////////////////////////////////////// + + template + class + PScorePbwd : public PhraseScorer + { + float conf; + char denom; + public: + PScorePbwd() + { + this->m_num_feats = 1; + } + + int + init(int const i, float const c, char d) + { + conf = c; + denom = d; + this->m_index = i; + ostringstream buf; + buf << format("pbwd-%c%.3f") % denom % c; + this->m_feature_names.push_back(buf.str()); + return i + this->m_num_feats; + } + + void + operator()(Bitext const& bt, PhrasePair& pp, + vector * dest = NULL) const + { + if (!dest) dest = &pp.fvals; + // we use the denominator 
specification to scale the raw counts on the + // target side; the clean way would be to counter-sample + uint32_t r2 = pp.raw2; + if (denom == 'g') r2 = round(r2 * float(pp.good1) / pp.raw1); + else if (denom == 's') r2 = round(r2 * float(pp.sample1) / pp.raw1); + (*dest)[this->m_index] = log(lbop(max(r2, pp.joint),pp.joint,conf)); + } + }; + + //////////////////////////////////////////////////////////////////////////////// + + template + class + PScoreCoherence : public PhraseScorer + { + public: + PScoreCoherence() + { + this->m_num_feats = 1; + } + + int + init(int const i) + { + this->m_index = i; + this->m_feature_names.push_back(string("coherence")); + return i + this->m_num_feats; + } + + void + operator()(Bitext const& bt, PhrasePair& pp, + vector * dest = NULL) const + { + if (!dest) dest = &pp.fvals; + (*dest)[this->m_index] = log(pp.good1) - log(pp.sample1); + } + }; + + //////////////////////////////////////////////////////////////////////////////// + + template + class + PScoreLogCounts : public PhraseScorer + { + float conf; + public: + PScoreLogCounts() + { + this->m_num_feats = 5; + } + + int + init(int const i) + { + this->m_index = i; + this->m_feature_names.push_back("log-r1"); + this->m_feature_names.push_back("log-s1"); + this->m_feature_names.push_back("log-g1"); + this->m_feature_names.push_back("log-j"); + this->m_feature_names.push_back("log-r2"); + return i + this->m_num_feats; + } + + void + operator()(Bitext const& bt, PhrasePair& pp, + vector * dest = NULL) const + { + if (!dest) dest = &pp.fvals; + size_t i = this->m_index; + assert(pp.raw1); + assert(pp.sample1); + assert(pp.good1); + assert(pp.joint); + assert(pp.raw2); + (*dest)[i] = -log(pp.raw1); + (*dest)[++i] = -log(pp.sample1); + (*dest)[++i] = -log(pp.good1); + (*dest)[++i] = +log(pp.joint); + (*dest)[++i] = -log(pp.raw2); + } + }; + + template + class + PScoreLex : public PhraseScorer + { + float const m_alpha; + public: + LexicalPhraseScorer2 scorer; + + PScoreLex(float 
const a) + : m_alpha(a) + { this->m_num_feats = 2; } + + int + init(int const i, string const& fname) + { + scorer.open(fname); + this->m_index = i; + this->m_feature_names.push_back("lexfwd"); + this->m_feature_names.push_back("lexbwd"); + return i + this->m_num_feats; + } + + void + operator()(Bitext const& bt, PhrasePair& pp, vector * dest = NULL) const + { + if (!dest) dest = &pp.fvals; + uint32_t sid1=0,sid2=0,off1=0,off2=0,len1=0,len2=0; + parse_pid(pp.p1, sid1, off1, len1); + parse_pid(pp.p2, sid2, off2, len2); + +#if 0 + cout << len1 << " " << len2 << endl; + Token const* t1 = bt.T1->sntStart(sid1); + for (size_t i = off1; i < off1 + len1; ++i) + cout << (*bt.V1)[t1[i].id()] << " "; + cout << __FILE__ << ":" << __LINE__ << endl; + + Token const* t2 = bt.T2->sntStart(sid2); + for (size_t i = off2; i < off2 + len2; ++i) + cout << (*bt.V2)[t2[i].id()] << " "; + cout << __FILE__ << ":" << __LINE__ << endl; + + BOOST_FOREACH (int a, pp.aln) + cout << a << " " ; + cout << __FILE__ << ":" << __LINE__ << "\n" << endl; + +#endif + scorer.score(bt.T1->sntStart(sid1)+off1,0,len1, + bt.T2->sntStart(sid2)+off2,0,len2, + pp.aln, m_alpha, + (*dest)[this->m_index], + (*dest)[this->m_index+1]); + } + + }; + + /// Word penalty + template + class + PScoreWP : public PhraseScorer + { + public: + + PScoreWP() { this->m_num_feats = 1; } + + int + init(int const i) + { + this->m_index = i; + return i + this->m_num_feats; + } + + void + operator()(Bitext const& bt, PhrasePair& pp, vector * dest = NULL) const + { + if (!dest) dest = &pp.fvals; + uint32_t sid2=0,off2=0,len2=0; + parse_pid(pp.p2, sid2, off2, len2); + (*dest)[this->m_index] = len2; + } + + }; + + /// Phrase penalty + template + class + PScorePP : public PhraseScorer + { + public: + + PScorePP() { this->m_num_feats = 1; } + + int + init(int const i) + { + this->m_index = i; + return i + this->m_num_feats; + } + + void + operator()(Bitext const& bt, PhrasePair& pp, vector * dest = NULL) const + { + if (!dest) dest = 
&pp.fvals; + (*dest)[this->m_index] = 1; + } + + }; + } +} diff --git a/moses/TranslationModel/UG/ptable-lookup.cc b/moses/TranslationModel/UG/ptable-lookup.cc new file mode 100644 index 000000000..106505f05 --- /dev/null +++ b/moses/TranslationModel/UG/ptable-lookup.cc @@ -0,0 +1,127 @@ +#include "mmsapt.h" +#include "moses/TranslationModel/PhraseDictionaryTreeAdaptor.h" +#include +#include +#include +#include +#include +#include + +using namespace Moses; +using namespace bitext; +using namespace std; +using namespace boost; + +vector fo(1,FactorType(0)); + +class SimplePhrase : public Moses::Phrase +{ + vector const m_fo; // factor order +public: + SimplePhrase(): m_fo(1,FactorType(0)) {} + + void init(string const& s) + { + istringstream buf(s); string w; + while (buf >> w) + { + Word wrd; + this->AddWord().CreateFromString(Input,m_fo,StringPiece(w),false,false); + } + } +}; + +class TargetPhraseIndexSorter +{ + TargetPhraseCollection const& my_tpc; + CompareTargetPhrase cmp; +public: + TargetPhraseIndexSorter(TargetPhraseCollection const& tpc) : my_tpc(tpc) {} + bool operator()(size_t a, size_t b) const + { + return cmp(*my_tpc[a], *my_tpc[b]); + } +}; + +int main(int argc, char* argv[]) +{ + Parameter params; + if (!params.LoadParam(argc,argv) || !StaticData::LoadDataStatic(¶ms, argv[0])) + exit(1); + + StaticData const& global = StaticData::Instance(); + global.SetVerboseLevel(0); + vector ifo = global.GetInputFactorOrder(); + + PhraseDictionary* PT = PhraseDictionary::GetColl()[0]; + Mmsapt* mmsapt = dynamic_cast(PT); + PhraseDictionaryTreeAdaptor* pdta = dynamic_cast(PT); + // vector const& ffs = FeatureFunction::GetFeatureFunctions(); + + if (!mmsapt && !pdta) + { + cerr << "Phrase table implementation not supported by this utility." 
<< endl; + exit(1); + } + + string line; + while (true) + { + Sentence phrase; + if (!phrase.Read(cin,ifo)) break; + if (pdta) + { + pdta->InitializeForInput(phrase); + // do we also need to call CleanupAfterSentenceProcessing at the end? + } + Phrase& p = phrase; + + cout << p << endl; + TargetPhraseCollection const* trg = PT->GetTargetPhraseCollectionLEGACY(p); + if (!trg) continue; + vector order(trg->GetSize()); + for (size_t i = 0; i < order.size(); ++i) order[i] = i; + sort(order.begin(),order.end(),TargetPhraseIndexSorter(*trg)); + size_t k = 0; + // size_t precision = + cout.precision(2); + + vector fname; + if (mmsapt) + { + fname = mmsapt->GetFeatureNames(); + cout << " "; + BOOST_FOREACH(string const& fn, fname) + cout << " " << format("%10.10s") % fn; + cout << endl; + } + + BOOST_FOREACH(size_t i, order) + { + Phrase const& phr = static_cast(*(*trg)[i]); + cout << setw(3) << ++k << " " << phr << endl; + ScoreComponentCollection const& scc = (*trg)[i]->GetScoreBreakdown(); + ScoreComponentCollection::IndexPair idx = scc.GetIndexes(PT); + FVector const& scores = scc.GetScoresVector(); + cout << " "; + for (size_t k = idx.first; k < idx.second; ++k) + { + if (mmsapt && fname[k-idx.first].substr(0,3) == "log") + { + if(scores[k] < 0) + cout << " " << format("%10d") % round(exp(-scores[k])); + else + cout << " " << format("%10d") % round(exp(scores[k])); + } + else + cout << " " << format("%10.8f") % exp(scores[k]); + } + cout << endl; + } + PT->Release(trg); + } + exit(0); +} + + + diff --git a/moses/TranslationModel/UG/util/Makefile b/moses/TranslationModel/UG/util/Makefile new file mode 100644 index 000000000..afe8c7b86 --- /dev/null +++ b/moses/TranslationModel/UG/util/Makefile @@ -0,0 +1,7 @@ +# -*- makefile -*- + +MOSES_CODE=/fs/gna0/germann/code/mosesdecoder +MOSES_ROOT=/fs/gna0/germann/moses +LIBS = $(addprefix -l,moses icuuc icuio icui18n boost_iostreams) +ibm1-align: ibm1-align.cc + g++ -o $@ -L ${MOSES_ROOT}/lib -I ${MOSES_CODE} $^ ${LIBS} -ggdb 
\ No newline at end of file diff --git a/moses/TranslationModel/UG/util/ibm1-align b/moses/TranslationModel/UG/util/ibm1-align new file mode 100755 index 000000000..2352dadb9 Binary files /dev/null and b/moses/TranslationModel/UG/util/ibm1-align differ diff --git a/moses/TranslationModel/UG/util/ibm1-align.cc b/moses/TranslationModel/UG/util/ibm1-align.cc new file mode 100644 index 000000000..08ac1f89b --- /dev/null +++ b/moses/TranslationModel/UG/util/ibm1-align.cc @@ -0,0 +1,164 @@ +// -*- c++ -*- +// Parallel text alignment via IBM1 / raw counts of word alignments +// aiming at high precision (to seed Yawat alignments) +// This program is tailored for use with Yawat. +// Written by Ulrich Germann. + +#include + +#include +#include +#include +#include +#include + +#include "moses/TranslationModel/UG/generic/file_io/ug_stream.h" +#include "moses/TranslationModel/UG/mm/tpt_tokenindex.h" +#include +#include "moses/TranslationModel/UG/mm/tpt_pickler.h" +#include "moses/TranslationModel/UG/mm/ug_mm_2d_table.h" + +using namespace std; +using namespace ugdiss; + +typedef mm2dTable table_t; + +class IBM1 +{ +public: + table_t COOC; + TokenIndex V1,V2; + + void + align(string const& s1, string const& s2, vector& aln) const; + + void + align(vector const& x1, + vector const& x2, + vector& aln) const; + + void + fill_amatrix(vector const& x1, + vector const& x2, + vector >& aln) const; + + void + open(string const base, string const L1, string const L2); +}; + +void +IBM1:: +open(string const base, string const L1, string const L2) +{ + V1.open(base+L1+".tdx"); + V2.open(base+L2+".tdx"); + COOC.open(base+L1+"-"+L2+".lex"); +} + +void +IBM1:: +align(string const& s1, string const& s2, vector& aln) const +{ + vector x1,x2; + V1.fillIdSeq(s1,x1); + V2.fillIdSeq(s2,x2); + align(x1,x2,aln); +} + + static UnicodeString apos = UnicodeString::fromUTF8(StringPiece("'")); + +string +u(StringPiece str, size_t start, size_t stop) +{ + string ret; + 
UnicodeString::fromUTF8(str).tempSubString(start,stop).toUTF8String(ret); + return ret; +} + +void +IBM1:: +fill_amatrix(vector const& x1, + vector const& x2, + vector >& aln) const +{ + aln.assign(x1.size(),vector(x2.size())); + for (size_t i = 0; i < x1.size(); ++i) + for (size_t k = 0; k < x2.size(); ++k) + aln[i][k] = COOC[x1[i]][x2[k]]; +#if 0 + cout << setw(10) << " "; + for (size_t k = 0; k < x2.size(); ++k) + cout << setw(7) << right << u(V2[x2[k]],0,6); + cout << endl; + for (size_t i = 0; i < x1.size(); ++i) + { + cout << setw(10) << u(V1[x1[i]],0,10); + for (size_t k = 0; k < x2.size(); ++k) + { + if (aln[i][k] > 999999) + cout << setw(7) << aln[i][k]/1000 << " K"; + else + cout << setw(7) << aln[i][k]; + } + cout << endl; + } +#endif +} + + +void +IBM1:: +align(vector const& x1, + vector const& x2, + vector& aln) const +{ + vector > M; + // fill_amatrix(x1,x2,M); + vector i1(x1.size(),0), max1(x1.size(),0); + vector i2(x2.size(),0), max2(x2.size(),0); + aln.clear(); + for (size_t i = 0; i < i1.size(); ++i) + { + for (size_t k = 0; k < i2.size(); ++k) + { + int c = COOC[x1[i]][x2[k]]; + if (c > max1[i]) { i1[i] = k; max1[i] = c; } + if (c >= max2[k]) { i2[k] = i; max2[k] = c; } + } + } + for (size_t i = 0; i < i1.size(); ++i) + { + if (max1[i] && i2[i1[i]] == i) + { + aln.push_back(i); + aln.push_back(i1[i]); + } + } +} + +int main(int argc, char* argv[]) +{ + IBM1 ibm1; + ibm1.open(argv[1],argv[2],argv[3]); + string line1,line2,sid; + while (getline(cin,sid)) + { + if (!getline(cin,line1)) assert(false); + if (!getline(cin,line2)) assert(false); + vector a; + vector s1,s2; + ibm1.V1.fillIdSeq(line1,s1); + ibm1.V2.fillIdSeq(line2,s2); + ibm1.align(s1,s2,a); + cout << sid; + for (size_t i = 0; i < a.size(); i += 2) + cout << " " << a[i] << ":" << a[i+1] << ":unspec"; + cout << endl; + // cout << line1 << endl; + // cout << line2 << endl; + // for (size_t i = 0; i < a.size(); i += 2) + // cout << ibm1.V1[s1[a[i]]] << " - " + // << ibm1.V2[s2[a[i+1]]] << 
endl; + } + // cout << endl; +} diff --git a/moses/TranslationModel/UG/util/tokenindex.dump.cc b/moses/TranslationModel/UG/util/tokenindex.dump.cc new file mode 100644 index 000000000..55970dbf0 --- /dev/null +++ b/moses/TranslationModel/UG/util/tokenindex.dump.cc @@ -0,0 +1,31 @@ +// (c) 2007,2008 Ulrich Germann +// Licensed to NRC-CNRC under special agreement. + +/** + * @author Ulrich Germann + * @file tokenindex.dump.cc + * @brief Dumps a TokenIndex (vocab file for TPPT and TPLM) to stdout. + */ + +#include "tpt_tokenindex.h" +#include +#include + +using namespace std; +using namespace ugdiss; +int +main(int argc,char* argv[]) +{ + if (argc > 1 && !strcmp(argv[1], "-h")) { + printf("Usage: %s \n\n", argv[0]); + cout << "Converts a phrase table in text format to a phrase table in tighly packed format." << endl; + cout << "input file: token index file" << endl; + exit(1); + } + + TokenIndex I; + I.open(argv[1]); + vector foo = I.reverseIndex(); + for (size_t i = 0; i < foo.size(); i++) + cout << setw(10) << i << " " << foo[i] << endl; +} diff --git a/moses/TranslationModel/fuzzy-match/FuzzyMatchWrapper.cpp b/moses/TranslationModel/fuzzy-match/FuzzyMatchWrapper.cpp index fc68e1f0d..8766743b3 100644 --- a/moses/TranslationModel/fuzzy-match/FuzzyMatchWrapper.cpp +++ b/moses/TranslationModel/fuzzy-match/FuzzyMatchWrapper.cpp @@ -413,11 +413,9 @@ void FuzzyMatchWrapper::load_corpus( const std::string &fileName, vector< vector istream *fileStreamP = &fileStream; - char line[LINE_MAX_LENGTH]; - while(true) { - SAFE_GETLINE((*fileStreamP), line, LINE_MAX_LENGTH, '\n'); - if (fileStreamP->eof()) break; - corpus.push_back( GetVocabulary().Tokenize( line ) ); + string line; + while(getline(*fileStreamP, line)) { + corpus.push_back( GetVocabulary().Tokenize( line.c_str() ) ); } } @@ -436,12 +434,9 @@ void FuzzyMatchWrapper::load_target(const std::string &fileName, vector< vector< WORD_ID delimiter = GetVocabulary().StoreIfNew("|||"); int lineNum = 0; - char 
line[LINE_MAX_LENGTH]; - while(true) { - SAFE_GETLINE((*fileStreamP), line, LINE_MAX_LENGTH, '\n'); - if (fileStreamP->eof()) break; - - vector toks = GetVocabulary().Tokenize( line ); + string line; + while(getline(*fileStreamP, line)) { + vector toks = GetVocabulary().Tokenize( line.c_str() ); corpus.push_back(vector< SentenceAlignment >()); vector< SentenceAlignment > &vec = corpus.back(); @@ -493,11 +488,8 @@ void FuzzyMatchWrapper::load_alignment(const std::string &fileName, vector< vect string delimiter = "|||"; int lineNum = 0; - char line[LINE_MAX_LENGTH]; - while(true) { - SAFE_GETLINE((*fileStreamP), line, LINE_MAX_LENGTH, '\n'); - if (fileStreamP->eof()) break; - + string line; + while(getline(*fileStreamP, line)) { vector< SentenceAlignment > &vec = corpus[lineNum]; size_t targetInd = 0; SentenceAlignment *sentence = &vec[targetInd]; diff --git a/moses/TranslationModel/fuzzy-match/SuffixArray.cpp b/moses/TranslationModel/fuzzy-match/SuffixArray.cpp index 536bff741..2930147ab 100644 --- a/moses/TranslationModel/fuzzy-match/SuffixArray.cpp +++ b/moses/TranslationModel/fuzzy-match/SuffixArray.cpp @@ -14,17 +14,16 @@ SuffixArray::SuffixArray( string fileName ) m_endOfSentence = m_vcb.StoreIfNew( "" ); ifstream extractFile; - char line[LINE_MAX_LENGTH]; // count the number of words first; extractFile.open(fileName.c_str()); istream *fileP = &extractFile; m_size = 0; size_t sentenceCount = 0; - while(!fileP->eof()) { - SAFE_GETLINE((*fileP), line, LINE_MAX_LENGTH, '\n'); - if (fileP->eof()) break; - vector< WORD_ID > words = m_vcb.Tokenize( line ); + string line; + while(getline(*fileP, line)) { + + vector< WORD_ID > words = m_vcb.Tokenize( line.c_str() ); m_size += words.size() + 1; sentenceCount++; } @@ -43,10 +42,8 @@ SuffixArray::SuffixArray( string fileName ) int sentenceId = 0; extractFile.open(fileName.c_str()); fileP = &extractFile; - while(!fileP->eof()) { - SAFE_GETLINE((*fileP), line, LINE_MAX_LENGTH, '\n'); - if (fileP->eof()) break; - vector< 
WORD_ID > words = m_vcb.Tokenize( line ); + while(getline(*fileP, line)) { + vector< WORD_ID > words = m_vcb.Tokenize( line.c_str() ); // add to corpus vector corpus.push_back(words); diff --git a/moses/TranslationModel/fuzzy-match/Vocabulary.h b/moses/TranslationModel/fuzzy-match/Vocabulary.h index dfa11c1db..5a79e2f26 100644 --- a/moses/TranslationModel/fuzzy-match/Vocabulary.h +++ b/moses/TranslationModel/fuzzy-match/Vocabulary.h @@ -17,20 +17,6 @@ namespace tmmt { - -#define MAX_LENGTH 10000 - -#define SAFE_GETLINE(_IS, _LINE, _SIZE, _DELIM) { \ - _IS.getline(_LINE, _SIZE, _DELIM); \ - if(_IS.fail() && !_IS.bad() && !_IS.eof()) _IS.clear(); \ - if (_IS.gcount() == _SIZE-1) { \ - cerr << "Line too long! Buffer overflow. Delete lines >=" \ - << _SIZE << " chars or raise MAX_LENGTH in phrase-extract/tables-core.cpp" \ - << endl; \ - exit(1); \ - } \ - } - typedef std::string WORD; typedef unsigned int WORD_ID; diff --git a/moses/TreeInput.cpp b/moses/TreeInput.cpp index 65c5c0b42..2b246aee5 100644 --- a/moses/TreeInput.cpp +++ b/moses/TreeInput.cpp @@ -145,8 +145,12 @@ bool TreeInput::ProcessAndStripXMLTags(string &line, std::vector VERBOSE(3,"XML TAG " << tagName << " (" << tagContent << ") spanning " << startPos << " to " << (endPos-1) << " complete, commence processing" << endl); - if (startPos >= endPos) { - TRACE_ERR("ERROR: tag " << tagName << " must span at least one word: " << line << endl); + if (startPos == endPos) { + TRACE_ERR("WARNING: tag " << tagName << " span is empty. 
Ignoring: " << line << endl); + continue; + } + else if (startPos > endPos) { + TRACE_ERR("ERROR: tag " << tagName << " startPos > endPos: " << line << endl); return false; } @@ -266,7 +270,10 @@ int TreeInput::Read(std::istream& in,const std::vector& factorOrder) // default label for (size_t startPos = 0; startPos < sourceSize; ++startPos) { for (size_t endPos = startPos; endPos < sourceSize; ++endPos) { - AddChartLabel(startPos, endPos, staticData.GetInputDefaultNonTerminal(), factorOrder); + NonTerminalSet &list = GetLabelSet(startPos, endPos); + if (list.size() == 0 || !staticData.GetDefaultNonTermOnlyForEmptyRange()) { + AddChartLabel(startPos, endPos, staticData.GetInputDefaultNonTerminal(), factorOrder); + } } } diff --git a/moses/Word.cpp b/moses/Word.cpp index 04cbdb6a7..b1ea77059 100644 --- a/moses/Word.cpp +++ b/moses/Word.cpp @@ -139,8 +139,7 @@ CreateFromString(FactorDirection direction << " contains factor delimiter " << StaticData::Instance().GetFactorDelimiter() << " too many times."); - - UTIL_THROW_IF(i < factorOrder.size(),util::Exception, + UTIL_THROW_IF(!isNonTerminal && i < factorOrder.size(),util::Exception, "Too few factors in string '" << str << "'."); } else diff --git a/moses/WordLattice.cpp b/moses/WordLattice.cpp index 7153f2d93..6be229491 100644 --- a/moses/WordLattice.cpp +++ b/moses/WordLattice.cpp @@ -56,7 +56,7 @@ InitializeFromPCNDataType const std::vector& factorOrder, const std::string& debug_line) { - const StaticData &staticData = StaticData::Instance(); + // const StaticData &staticData = StaticData::Instance(); const InputFeature &inputFeature = InputFeature::Instance(); size_t numInputScores = inputFeature.GetNumInputScores(); size_t numRealWordCount = inputFeature.GetNumRealWordsInInput(); diff --git a/moses/XmlOption.cpp b/moses/XmlOption.cpp index 21ce0a411..c42b200de 100644 --- a/moses/XmlOption.cpp +++ b/moses/XmlOption.cpp @@ -308,10 +308,14 @@ bool ProcessAndStripXMLTags(string &line, vector &res, ReorderingCon // 
default: opening tag that specifies translation options else { - if (startPos >= endPos) { - TRACE_ERR("ERROR: tag " << tagName << " must span at least one word: " << line << endl); + if (startPos > endPos) { + TRACE_ERR("ERROR: tag " << tagName << " startPos > endPos: " << line << endl); return false; } + else if (startPos == endPos) { + TRACE_ERR("WARNING: tag " << tagName << " 0 span: " << line << endl); + continue; + } // specified translations -> vector of phrases // multiple translations may be specified, separated by "||" diff --git a/phrase-extract/DomainFeature.cpp b/phrase-extract/DomainFeature.cpp index 2f99a8709..99f0713a7 100644 --- a/phrase-extract/DomainFeature.cpp +++ b/phrase-extract/DomainFeature.cpp @@ -2,9 +2,6 @@ #include "ExtractionPhrasePair.h" #include "tables-core.h" #include "InputFileStream.h" -#include "SafeGetline.h" - -#define TABLE_LINE_MAX_LENGTH 1000 using namespace std; @@ -16,12 +13,11 @@ void Domain::load( const std::string &domainFileName ) { Moses::InputFileStream fileS( domainFileName ); istream *fileP = &fileS; - while(true) { - char line[TABLE_LINE_MAX_LENGTH]; - SAFE_GETLINE((*fileP), line, TABLE_LINE_MAX_LENGTH, '\n', __FILE__); - if (fileP->eof()) break; + + string line; + while(getline(*fileP, line)) { // read - vector< string > domainSpecLine = tokenize( line ); + vector< string > domainSpecLine = tokenize( line.c_str() ); int lineNumber; if (domainSpecLine.size() != 2 || ! 
sscanf(domainSpecLine[0].c_str(), "%d", &lineNumber)) { diff --git a/phrase-extract/ExtractionPhrasePair.cpp b/phrase-extract/ExtractionPhrasePair.cpp index f70d106d1..9564b1cfe 100644 --- a/phrase-extract/ExtractionPhrasePair.cpp +++ b/phrase-extract/ExtractionPhrasePair.cpp @@ -19,7 +19,6 @@ #include #include "ExtractionPhrasePair.h" -#include "SafeGetline.h" #include "tables-core.h" #include "score.h" #include "moses/Util.h" @@ -322,5 +321,148 @@ std::string ExtractionPhrasePair::CollectAllPropertyValues(const std::string &ke } +std::string ExtractionPhrasePair::CollectAllLabelsSeparateLHSAndRHS(const std::string& propertyKey, + std::set& labelSet, + boost::unordered_map& countsLabelsLHS, + boost::unordered_map* >& jointCountsRulesTargetLHSAndLabelsLHS, + Vocabulary &vcbT) const +{ + const PROPERTY_VALUES *allPropertyValues = GetProperty( propertyKey ); + + if ( allPropertyValues == NULL ) { + return ""; + } + + std::string lhs="", rhs="", currentRhs=""; + float currentRhsCount = 0.0; + std::list< std::pair > lhsGivenCurrentRhsCounts; + + std::ostringstream oss; + for (PROPERTY_VALUES::const_iterator iter=allPropertyValues->begin(); + iter!=allPropertyValues->end(); ++iter) { + + size_t space = (iter->first).find_last_of(' '); + if ( space == string::npos ) { + lhs = iter->first; + rhs.clear(); + } else { + lhs = (iter->first).substr(space+1); + rhs = (iter->first).substr(0,space); + } + + labelSet.insert(lhs); + + if ( rhs.compare(currentRhs) ) { + + if ( iter!=allPropertyValues->begin() ) { + if ( !currentRhs.empty() ) { + istringstream tokenizer(currentRhs); + std::string rhsLabel; + while ( tokenizer.peek() != EOF ) { + tokenizer >> rhsLabel; + labelSet.insert(rhsLabel); + } + oss << " " << currentRhs << " " << currentRhsCount; + } + if ( lhsGivenCurrentRhsCounts.size() > 0 ) { + if ( !currentRhs.empty() ) { + oss << " " << lhsGivenCurrentRhsCounts.size(); + } + for ( std::list< std::pair >::const_iterator iter2=lhsGivenCurrentRhsCounts.begin(); + 
iter2!=lhsGivenCurrentRhsCounts.end(); ++iter2 ) { + oss << " " << iter2->first << " " << iter2->second; + + // update countsLabelsLHS and jointCountsRulesTargetLHSAndLabelsLHS + std::string ruleTargetLhs = vcbT.getWord(m_phraseTarget->back()); + ruleTargetLhs.erase(ruleTargetLhs.begin()); // strip square brackets + ruleTargetLhs.erase(ruleTargetLhs.size()-1); + + std::pair< boost::unordered_map::iterator, bool > insertedCountsLabelsLHS = + countsLabelsLHS.insert(std::pair(iter2->first,iter2->second)); + if (!insertedCountsLabelsLHS.second) { + (insertedCountsLabelsLHS.first)->second += iter2->second; + } + + boost::unordered_map* >::iterator jointCountsRulesTargetLHSAndLabelsLHSIter = + jointCountsRulesTargetLHSAndLabelsLHS.find(ruleTargetLhs); + if ( jointCountsRulesTargetLHSAndLabelsLHSIter == jointCountsRulesTargetLHSAndLabelsLHS.end() ) { + boost::unordered_map* jointCounts = new boost::unordered_map; + jointCounts->insert(std::pair(iter2->first,iter2->second)); + jointCountsRulesTargetLHSAndLabelsLHS.insert(std::pair* >(ruleTargetLhs,jointCounts)); + } else { + boost::unordered_map* jointCounts = jointCountsRulesTargetLHSAndLabelsLHSIter->second; + std::pair< boost::unordered_map::iterator, bool > insertedJointCounts = + jointCounts->insert(std::pair(iter2->first,iter2->second)); + if (!insertedJointCounts.second) { + (insertedJointCounts.first)->second += iter2->second; + } + } + + } + } + + lhsGivenCurrentRhsCounts.clear(); + } + + currentRhsCount = 0.0; + currentRhs = rhs; + } + + currentRhsCount += iter->second; + lhsGivenCurrentRhsCounts.push_back( std::pair(lhs,iter->second) ); + } + + if ( !currentRhs.empty() ) { + istringstream tokenizer(currentRhs); + std::string rhsLabel; + while ( tokenizer.peek() != EOF ) { + tokenizer >> rhsLabel; + labelSet.insert(rhsLabel); + } + oss << " " << currentRhs << " " << currentRhsCount; + } + if ( lhsGivenCurrentRhsCounts.size() > 0 ) { + if ( !currentRhs.empty() ) { + oss << " " << lhsGivenCurrentRhsCounts.size(); + 
} + for ( std::list< std::pair >::const_iterator iter2=lhsGivenCurrentRhsCounts.begin(); + iter2!=lhsGivenCurrentRhsCounts.end(); ++iter2 ) { + oss << " " << iter2->first << " " << iter2->second; + + // update countsLabelsLHS and jointCountsRulesTargetLHSAndLabelsLHS + std::string ruleTargetLhs = vcbT.getWord(m_phraseTarget->back()); + ruleTargetLhs.erase(ruleTargetLhs.begin()); // strip square brackets + ruleTargetLhs.erase(ruleTargetLhs.size()-1); + + std::pair< boost::unordered_map::iterator, bool > insertedCountsLabelsLHS = + countsLabelsLHS.insert(std::pair(iter2->first,iter2->second)); + if (!insertedCountsLabelsLHS.second) { + (insertedCountsLabelsLHS.first)->second += iter2->second; + } + + boost::unordered_map* >::iterator jointCountsRulesTargetLHSAndLabelsLHSIter = + jointCountsRulesTargetLHSAndLabelsLHS.find(ruleTargetLhs); + if ( jointCountsRulesTargetLHSAndLabelsLHSIter == jointCountsRulesTargetLHSAndLabelsLHS.end() ) { + boost::unordered_map* jointCounts = new boost::unordered_map; + jointCounts->insert(std::pair(iter2->first,iter2->second)); + jointCountsRulesTargetLHSAndLabelsLHS.insert(std::pair* >(ruleTargetLhs,jointCounts)); + } else { + boost::unordered_map* jointCounts = jointCountsRulesTargetLHSAndLabelsLHSIter->second; + std::pair< boost::unordered_map::iterator, bool > insertedJointCounts = + jointCounts->insert(std::pair(iter2->first,iter2->second)); + if (!insertedJointCounts.second) { + (insertedJointCounts.first)->second += iter2->second; + } + } + + } + } + + std::string allPropertyValuesString(oss.str()); + return allPropertyValuesString; +} + + + } diff --git a/phrase-extract/ExtractionPhrasePair.h b/phrase-extract/ExtractionPhrasePair.h index f04984391..ba23ac1f2 100644 --- a/phrase-extract/ExtractionPhrasePair.h +++ b/phrase-extract/ExtractionPhrasePair.h @@ -23,6 +23,7 @@ #include #include #include +#include namespace MosesTraining { @@ -124,6 +125,12 @@ public: std::string CollectAllPropertyValues(const std::string &key) const; + 
std::string CollectAllLabelsSeparateLHSAndRHS(const std::string& propertyKey, + std::set& sourceLabelSet, + boost::unordered_map& sourceLHSCounts, + boost::unordered_map* >& sourceRHSAndLHSJointCounts, + Vocabulary &vcbT) const; + void AddProperties( const std::string &str, float count ); void AddProperty( const std::string &key, const std::string &value, float count ) diff --git a/phrase-extract/SafeGetline.h b/phrase-extract/SafeGetline.h deleted file mode 100644 index 0e03b8468..000000000 --- a/phrase-extract/SafeGetline.h +++ /dev/null @@ -1,35 +0,0 @@ -/*********************************************************************** - Moses - factored phrase-based language decoder - Copyright (C) 2010 University of Edinburgh - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. - - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - ***********************************************************************/ - -#pragma once -#ifndef SAFE_GETLINE_INCLUDED_ -#define SAFE_GETLINE_INCLUDED_ - -#define SAFE_GETLINE(_IS, _LINE, _SIZE, _DELIM, _FILE) { \ - _IS.getline(_LINE, _SIZE, _DELIM); \ - if(_IS.fail() && !_IS.bad() && !_IS.eof()) _IS.clear(); \ - if (_IS.gcount() == _SIZE-1) { \ - cerr << "Line too long! Buffer overflow. 
Delete lines >=" \ - << _SIZE << " chars or raise LINE_MAX_LENGTH in " << _FILE \ - << endl; \ - exit(1); \ - } \ - } - -#endif diff --git a/phrase-extract/ScoreFeature.h b/phrase-extract/ScoreFeature.h index 926397e71..30e198e21 100644 --- a/phrase-extract/ScoreFeature.h +++ b/phrase-extract/ScoreFeature.h @@ -90,7 +90,7 @@ public: float count, int sentenceId) const {}; - /** Add the values for this feature function. */ + /** Add the values for this score feature. */ virtual void add(const ScoreFeatureContext& context, std::vector& denseValues, std::map& sparseValues) const = 0; diff --git a/phrase-extract/SentenceAlignment.cpp b/phrase-extract/SentenceAlignment.cpp index c3d71d525..120c9154d 100644 --- a/phrase-extract/SentenceAlignment.cpp +++ b/phrase-extract/SentenceAlignment.cpp @@ -54,7 +54,11 @@ bool SentenceAlignment::processSourceSentence(const char * sourceString, int, bo return true; } -bool SentenceAlignment::create( char targetString[], char sourceString[], char alignmentString[], char weightString[], int sentenceID, bool boundaryRules) +bool SentenceAlignment::create(const char targetString[], + const char sourceString[], + const char alignmentString[], + const char weightString[], + int sentenceID, bool boundaryRules) { using namespace std; this->sentenceID = sentenceID; diff --git a/phrase-extract/SentenceAlignment.h b/phrase-extract/SentenceAlignment.h index 1df61cf02..576d3279e 100644 --- a/phrase-extract/SentenceAlignment.h +++ b/phrase-extract/SentenceAlignment.h @@ -43,8 +43,11 @@ public: virtual bool processSourceSentence(const char *, int, bool boundaryRules); - bool create(char targetString[], char sourceString[], - char alignmentString[], char weightString[], int sentenceID, bool boundaryRules); + bool create(const char targetString[], + const char sourceString[], + const char alignmentString[], + const char weightString[], + int sentenceID, bool boundaryRules); void invertAlignment(); diff --git a/phrase-extract/XmlTree.cpp 
b/phrase-extract/XmlTree.cpp index dcb974bef..ce7e6837e 100644 --- a/phrase-extract/XmlTree.cpp +++ b/phrase-extract/XmlTree.cpp @@ -248,7 +248,6 @@ bool ProcessAndStripXMLTags(string &line, SyntaxTree &tree, set< string > &label string cleanLine; // return string (text without xml) size_t wordPos = 0; // position in sentence (in terms of number of words) - bool isLinked = false; // loop through the tokens for (size_t xmlTokenPos = 0 ; xmlTokenPos < xmlTokens.size() ; xmlTokenPos++) { @@ -354,10 +353,14 @@ bool ProcessAndStripXMLTags(string &line, SyntaxTree &tree, set< string > &label // cerr << "XML TAG " << tagName << " (" << tagContent << ") spanning " << startPos << " to " << (endPos-1) << " complete, commence processing" << endl; - if (startPos >= endPos) { - cerr << "ERROR: tag " << tagName << " must span at least one word (" << startPos << "-" << endPos << "): " << line << endl; + if (startPos > endPos) { + cerr << "ERROR: tag " << tagName << " startPos is bigger than endPos (" << startPos << "-" << endPos << "): " << line << endl; return false; } + else if (startPos == endPos) { + cerr << "WARNING: tag " << tagName << ". 
Ignoring 0 span (" << startPos << "-" << endPos << "): " << line << endl; + continue; + } string label = ParseXmlTagAttribute(tagContent,"label"); labelCollection.insert( label ); diff --git a/phrase-extract/consolidate-direct-main.cpp b/phrase-extract/consolidate-direct-main.cpp index 3b38f741c..40e0e35d4 100644 --- a/phrase-extract/consolidate-direct-main.cpp +++ b/phrase-extract/consolidate-direct-main.cpp @@ -26,16 +26,9 @@ #include "InputFileStream.h" #include "OutputFileStream.h" -#include "SafeGetline.h" - -#define LINE_MAX_LENGTH 10000 - using namespace std; -char line[LINE_MAX_LENGTH]; - - -vector< string > splitLine() +vector< string > splitLine(const char *line) { vector< string > item; int start=0; @@ -61,14 +54,15 @@ bool getLine( istream &fileP, vector< string > &item ) { if (fileP.eof()) return false; - - SAFE_GETLINE((fileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - if (fileP.eof()) + + string line; + if (getline(fileP, line)) { + item = splitLine(line.c_str()); return false; - - item = splitLine(); - - return true; + } + else { + return false; + } } diff --git a/phrase-extract/consolidate-main.cpp b/phrase-extract/consolidate-main.cpp index de0d7f646..a2174805c 100644 --- a/phrase-extract/consolidate-main.cpp +++ b/phrase-extract/consolidate-main.cpp @@ -26,12 +26,9 @@ #include #include "tables-core.h" -#include "SafeGetline.h" #include "InputFileStream.h" #include "OutputFileStream.h" -#define LINE_MAX_LENGTH 10000 - using namespace std; bool hierarchicalFlag = false; @@ -46,12 +43,11 @@ inline float maybeLogProb( float a ) return logProbFlag ? 
log(a) : a; } -char line[LINE_MAX_LENGTH]; void processFiles( char*, char*, char*, char* ); void loadCountOfCounts( char* ); void breakdownCoreAndSparse( string combined, string &core, string &sparse ); bool getLine( istream &fileP, vector< string > &item ); -vector< string > splitLine(); +vector< string > splitLine(const char *line); vector< int > countBin; bool sparseCountBinFeatureFlag = false; @@ -140,14 +136,13 @@ void loadCountOfCounts( char* fileNameCountOfCounts ) istream &fileP = fileCountOfCounts; countOfCounts.push_back(0.0); - while(1) { - if (fileP.eof()) break; - SAFE_GETLINE((fileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - if (fileP.eof()) break; + + string line; + while (getline(fileP, line)) { if (totalCount < 0) - totalCount = atof(line); // total number of distinct phrase pairs + totalCount = atof(line.c_str()); // total number of distinct phrase pairs else - countOfCounts.push_back( atof(line) ); + countOfCounts.push_back( atof(line.c_str()) ); } fileCountOfCounts.Close(); @@ -337,8 +332,9 @@ void processFiles( char* fileNameDirect, char* fileNameIndirect, char* fileNameC } // arbitrary key-value pairs + fileConsolidated << " ||| "; if (itemDirect.size() >= 6) { - fileConsolidated << " ||| " << itemDirect[5]; + fileConsolidated << itemDirect[5]; } fileConsolidated << endl; @@ -370,16 +366,16 @@ bool getLine( istream &fileP, vector< string > &item ) if (fileP.eof()) return false; - SAFE_GETLINE((fileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - if (fileP.eof()) + string line; + if (!getline(fileP, line)) return false; - item = splitLine(); + item = splitLine(line.c_str()); return true; } -vector< string > splitLine() +vector< string > splitLine(const char *line) { vector< string > item; int start=0; diff --git a/phrase-extract/consolidate-reverse-main.cpp b/phrase-extract/consolidate-reverse-main.cpp index 6843bf3aa..ce59315b9 100644 --- a/phrase-extract/consolidate-reverse-main.cpp +++ b/phrase-extract/consolidate-reverse-main.cpp @@ -27,23 
+27,19 @@ #include #include "tables-core.h" -#include "SafeGetline.h" #include "InputFileStream.h" -#define LINE_MAX_LENGTH 10000 - using namespace std; bool hierarchicalFlag = false; bool onlyDirectFlag = false; bool phraseCountFlag = true; bool logProbFlag = false; -char line[LINE_MAX_LENGTH]; void processFiles( char*, char*, char* ); bool getLine( istream &fileP, vector< string > &item ); string reverseAlignment(const string &alignments); -vector< string > splitLine(); +vector< string > splitLine(const char *lin); inline void Tokenize(std::vector &output , const std::string& str @@ -190,17 +186,18 @@ bool getLine( istream &fileP, vector< string > &item ) { if (fileP.eof()) return false; - - SAFE_GETLINE((fileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - if (fileP.eof()) + + string line; + if (getline(fileP, line)) { + item = splitLine(line.c_str()); return false; - - item = splitLine(); - - return true; + } + else { + return false; + } } -vector< string > splitLine() +vector< string > splitLine(const char *line) { vector< string > item; bool betweenWords = true; diff --git a/phrase-extract/extract-ghkm/ExtractGHKM.cpp b/phrase-extract/extract-ghkm/ExtractGHKM.cpp index 5b12203a5..b86c28586 100644 --- a/phrase-extract/extract-ghkm/ExtractGHKM.cpp +++ b/phrase-extract/extract-ghkm/ExtractGHKM.cpp @@ -30,6 +30,10 @@ #include "ScfgRule.h" #include "ScfgRuleWriter.h" #include "Span.h" +#include "SyntaxTree.h" +#include "tables-core.h" +#include "XmlException.h" +#include "XmlTree.h" #include "XmlTreeParser.h" #include @@ -63,7 +67,9 @@ int ExtractGHKM::Main(int argc, char *argv[]) OutputFileStream fwdExtractStream; OutputFileStream invExtractStream; std::ofstream glueGrammarStream; - std::ofstream unknownWordStream; + std::ofstream targetUnknownWordStream; + std::ofstream sourceUnknownWordStream; + std::ofstream sourceLabelSetStream; std::ofstream unknownWordSoftMatchesStream; std::string fwdFileName = options.extractFile; std::string invFileName = 
options.extractFile + std::string(".inv"); @@ -76,26 +82,44 @@ int ExtractGHKM::Main(int argc, char *argv[]) if (!options.glueGrammarFile.empty()) { OpenOutputFileOrDie(options.glueGrammarFile, glueGrammarStream); } - if (!options.unknownWordFile.empty()) { - OpenOutputFileOrDie(options.unknownWordFile, unknownWordStream); + if (!options.targetUnknownWordFile.empty()) { + OpenOutputFileOrDie(options.targetUnknownWordFile, targetUnknownWordStream); + } + if (!options.sourceUnknownWordFile.empty()) { + OpenOutputFileOrDie(options.sourceUnknownWordFile, sourceUnknownWordStream); + } + if (!options.sourceLabelSetFile.empty()) { + if (!options.sourceLabels) { + Error("SourceLabels should be active if SourceLabelSet is supposed to be written to a file"); + } + OpenOutputFileOrDie(options.sourceLabelSetFile, sourceLabelSetStream); // TODO: global sourceLabelSet cannot be determined during parallelized extraction } if (!options.unknownWordSoftMatchesFile.empty()) { OpenOutputFileOrDie(options.unknownWordSoftMatchesFile, unknownWordSoftMatchesStream); } // Target label sets for producing glue grammar. - std::set labelSet; - std::map topLabelSet; + std::set targetLabelSet; + std::map targetTopLabelSet; + + // Source label sets for producing glue grammar. + std::set sourceLabelSet; + std::map sourceTopLabelSet; // Word count statistics for producing unknown word labels. - std::map wordCount; - std::map wordLabel; + std::map targetWordCount; + std::map targetWordLabel; + + // Word count statistics for producing unknown word labels: source side. 
+ std::map sourceWordCount; + std::map sourceWordLabel; std::string targetLine; std::string sourceLine; std::string alignmentLine; Alignment alignment; - XmlTreeParser xmlTreeParser(labelSet, topLabelSet); + XmlTreeParser xmlTreeParser(targetLabelSet, targetTopLabelSet); +// XmlTreeParser sourceXmlTreeParser(sourceLabelSet, sourceTopLabelSet); ScfgRuleWriter writer(fwdExtractStream, invExtractStream, options); size_t lineNum = options.sentenceOffset; while (true) { @@ -118,30 +142,71 @@ int ExtractGHKM::Main(int argc, char *argv[]) std::cerr << "skipping line " << lineNum << " with empty target tree\n"; continue; } - std::auto_ptr t; + std::auto_ptr targetParseTree; try { - t = xmlTreeParser.Parse(targetLine); - assert(t.get()); + targetParseTree = xmlTreeParser.Parse(targetLine); + assert(targetParseTree.get()); } catch (const Exception &e) { - std::ostringstream s; - s << "Failed to parse XML tree at line " << lineNum; + std::ostringstream oss; + oss << "Failed to parse target XML tree at line " << lineNum; if (!e.GetMsg().empty()) { - s << ": " << e.GetMsg(); + oss << ": " << e.GetMsg(); + } + Error(oss.str()); + } + + + // Parse source tree and construct a SyntaxTree object. + MosesTraining::SyntaxTree sourceSyntaxTree; + MosesTraining::SyntaxNode *sourceSyntaxTreeRoot=NULL; + + if (options.sourceLabels) { + try { + if (!ProcessAndStripXMLTags(sourceLine, sourceSyntaxTree, sourceLabelSet, sourceTopLabelSet, false)) { + throw Exception(""); + } + sourceSyntaxTree.ConnectNodes(); + sourceSyntaxTreeRoot = sourceSyntaxTree.GetTop(); + assert(sourceSyntaxTreeRoot); + } catch (const Exception &e) { + std::ostringstream oss; + oss << "Failed to parse source XML tree at line " << lineNum; + if (!e.GetMsg().empty()) { + oss << ": " << e.GetMsg(); + } + Error(oss.str()); } - Error(s.str()); } // Read source tokens. std::vector sourceTokens(ReadTokens(sourceLine)); + // Construct a source ParseTree object object from the SyntaxTree object. 
+ std::auto_ptr sourceParseTree; + + if (options.sourceLabels) { + try { + sourceParseTree = XmlTreeParser::ConvertTree(*sourceSyntaxTreeRoot, sourceTokens); + assert(sourceParseTree.get()); + } catch (const Exception &e) { + std::ostringstream oss; + oss << "Failed to parse source XML tree at line " << lineNum; + if (!e.GetMsg().empty()) { + oss << ": " << e.GetMsg(); + } + Error(oss.str()); + } + } + + // Read word alignments. try { ReadAlignment(alignmentLine, alignment); } catch (const Exception &e) { - std::ostringstream s; - s << "Failed to read alignment at line " << lineNum << ": "; - s << e.GetMsg(); - Error(s.str()); + std::ostringstream oss; + oss << "Failed to read alignment at line " << lineNum << ": "; + oss << e.GetMsg(); + Error(oss.str()); } if (alignment.size() == 0) { std::cerr << "skipping line " << lineNum << " without alignment points\n"; @@ -149,13 +214,18 @@ int ExtractGHKM::Main(int argc, char *argv[]) } // Record word counts. - if (!options.unknownWordFile.empty()) { - CollectWordLabelCounts(*t, options, wordCount, wordLabel); + if (!options.targetUnknownWordFile.empty()) { + CollectWordLabelCounts(*targetParseTree, options, targetWordCount, targetWordLabel); + } + + // Record word counts: source side. + if (options.sourceLabels && !options.sourceUnknownWordFile.empty()) { + CollectWordLabelCounts(*sourceParseTree, options, sourceWordCount, sourceWordLabel); } // Form an alignment graph from the target tree, source words, and // alignment. - AlignmentGraph graph(t.get(), sourceTokens, alignment); + AlignmentGraph graph(targetParseTree.get(), sourceTokens, alignment); // Extract minimal rules, adding each rule to its root node's rule set. 
graph.ExtractMinimalRules(options); @@ -172,29 +242,54 @@ int ExtractGHKM::Main(int argc, char *argv[]) const std::vector &rules = (*p)->GetRules(); for (std::vector::const_iterator q = rules.begin(); q != rules.end(); ++q) { - ScfgRule r(**q); + ScfgRule *r = 0; + if (options.sourceLabels) { + r = new ScfgRule(**q, &sourceSyntaxTree); + } else { + r = new ScfgRule(**q); + } // TODO Can scope pruning be done earlier? - if (r.Scope() <= options.maxScope) { + if (r->Scope() <= options.maxScope) { if (!options.treeFragments) { - writer.Write(r); + writer.Write(*r); } else { - writer.Write(r,**q); + writer.Write(*r,**q); } } + delete r; } } } - if (!options.glueGrammarFile.empty()) { - WriteGlueGrammar(labelSet, topLabelSet, glueGrammarStream); + std::map sourceLabels; + if (options.sourceLabels && !options.sourceLabelSetFile.empty()) { + + sourceLabelSet.insert("XLHS"); // non-matching label (left-hand side) + sourceLabelSet.insert("XRHS"); // non-matching label (right-hand side) + sourceLabelSet.insert("TOPLABEL"); // as used in the glue grammar + sourceLabelSet.insert("SOMELABEL"); // as used in the glue grammar + size_t index = 0; + for (std::set::const_iterator iter=sourceLabelSet.begin(); + iter!=sourceLabelSet.end(); ++iter, ++index) { + sourceLabels.insert(std::pair(*iter,index)); + } + WriteSourceLabelSet(sourceLabels, sourceLabelSetStream); } - if (!options.unknownWordFile.empty()) { - WriteUnknownWordLabel(wordCount, wordLabel, options, unknownWordStream); + if (!options.glueGrammarFile.empty()) { + WriteGlueGrammar(targetLabelSet, targetTopLabelSet, sourceLabels, options, glueGrammarStream); + } + + if (!options.targetUnknownWordFile.empty()) { + WriteUnknownWordLabel(targetWordCount, targetWordLabel, options, targetUnknownWordStream); + } + + if (options.sourceLabels && !options.sourceUnknownWordFile.empty()) { + WriteUnknownWordLabel(sourceWordCount, sourceWordLabel, options, sourceUnknownWordStream, true); } if 
(!options.unknownWordSoftMatchesFile.empty()) { - WriteUnknownWordSoftMatches(labelSet, unknownWordSoftMatchesStream); + WriteUnknownWordSoftMatches(targetLabelSet, unknownWordSoftMatchesStream); } return 0; @@ -305,12 +400,20 @@ void ExtractGHKM::ProcessOptions(int argc, char *argv[], "include score based on PCFG scores in target corpus") ("TreeFragments", "output parse tree information") + ("SourceLabels", + "output source syntax label information") + ("SourceLabelSet", + po::value(&options.sourceLabelSetFile), + "write source syntax label set to named file") ("SentenceOffset", po::value(&options.sentenceOffset)->default_value(options.sentenceOffset), "set sentence number offset if processing split corpus") ("UnknownWordLabel", - po::value(&options.unknownWordFile), + po::value(&options.targetUnknownWordFile), "write unknown word labels to named file") + ("SourceUnknownWordLabel", + po::value(&options.sourceUnknownWordFile), + "write source syntax unknown word labels to named file") ("UnknownWordMinRelFreq", po::value(&options.unknownWordMinRelFreq)->default_value( options.unknownWordMinRelFreq), @@ -402,6 +505,9 @@ void ExtractGHKM::ProcessOptions(int argc, char *argv[], if (vm.count("TreeFragments")) { options.treeFragments = true; } + if (vm.count("SourceLabels")) { + options.sourceLabels = true; + } if (vm.count("UnknownWordUniform")) { options.unknownWordUniform = true; } @@ -411,7 +517,10 @@ void ExtractGHKM::ProcessOptions(int argc, char *argv[], // Workaround for extract-parallel issue. 
if (options.sentenceOffset > 0) { - options.unknownWordFile.clear(); + options.targetUnknownWordFile.clear(); + } + if (options.sentenceOffset > 0) { + options.sourceUnknownWordFile.clear(); options.unknownWordSoftMatchesFile.clear(); } } @@ -422,7 +531,7 @@ void ExtractGHKM::Error(const std::string &msg) const std::exit(1); } -std::vector ExtractGHKM::ReadTokens(const std::string &s) +std::vector ExtractGHKM::ReadTokens(const std::string &s) const { std::vector tokens; @@ -454,9 +563,11 @@ std::vector ExtractGHKM::ReadTokens(const std::string &s) void ExtractGHKM::WriteGlueGrammar( const std::set &labelSet, const std::map &topLabelSet, + const std::map &sourceLabels, + const Options &options, std::ostream &out) { - // chose a top label that is not already a label + // choose a top label that is not already a label std::string topLabel = "QQQQQQ"; for(size_t i = 1; i <= topLabel.length(); i++) { if (labelSet.find(topLabel.substr(0,i)) == labelSet.end() ) { @@ -465,23 +576,75 @@ void ExtractGHKM::WriteGlueGrammar( } } + std::string sourceTopLabel = "TOPLABEL"; + std::string sourceSLabel = "S"; + std::string sourceSomeLabel = "SOMELABEL"; + // basic rules - out << " [X] ||| [" << topLabel << "] ||| 1 ||| ||| ||| ||| {{Tree [" << topLabel << " ]}}" << std::endl; - out << "[X][" << topLabel << "] [X] ||| [X][" << topLabel << "] [" << topLabel << "] ||| 1 ||| 0-0 ||| ||| ||| {{Tree [" << topLabel << " [" << topLabel << "] ]}}" << std::endl; + out << " [X] ||| [" << topLabel << "] ||| 1 ||| ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " ]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 1 1 " << sourceTopLabel << " 1}}"; + } + out << std::endl; + + out << "[X][" << topLabel << "] [X] ||| [X][" << topLabel << "] [" << topLabel << "] ||| 1 ||| 0-0 ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " [" << topLabel << "] ]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 2 1 " << 
sourceTopLabel << " 1 1 " << sourceTopLabel << " 1}}"; + } + out << std::endl; // top rules for (std::map::const_iterator i = topLabelSet.begin(); i != topLabelSet.end(); ++i) { - out << " [X][" << i->first << "] [X] ||| [X][" << i->first << "] [" << topLabel << "] ||| 1 ||| 1-1 ||| ||| ||| {{Tree [" << topLabel << " [" << i->first << "] ]}}" << std::endl; + out << " [X][" << i->first << "] [X] ||| [X][" << i->first << "] [" << topLabel << "] ||| 1 ||| 1-1 ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " [" << i->first << "] ]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 2 1 " << sourceSLabel << " 1 1 " << sourceTopLabel << " 1}}"; + } + out << std::endl; } // glue rules for(std::set::const_iterator i = labelSet.begin(); i != labelSet.end(); i++ ) { - out << "[X][" << topLabel << "] [X][" << *i << "] [X] ||| [X][" << topLabel << "] [X][" << *i << "] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 ||| ||| ||| {{Tree [" << topLabel << " ["<< topLabel << "] [" << *i << "]]}}" << std::endl; + out << "[X][" << topLabel << "] [X][" << *i << "] [X] ||| [X][" << topLabel << "] [X][" << *i << "] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " ["<< topLabel << "] [" << *i << "]]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 3 2.718 " << sourceTopLabel << " " << sourceSomeLabel << " 2.718 1 " << sourceTopLabel << " 2.718}}"; // TODO: there should be better options than using "SOMELABEL" + } + out << std::endl; } + // glue rule for unknown word... 
- out << "[X][" << topLabel << "] [X][X] [X] ||| [X][" << topLabel << "] [X][X] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 ||| ||| ||| {{Tree [" << topLabel << " [" << topLabel << "] [X]]}}" << std::endl; + out << "[X][" << topLabel << "] [X][X] [X] ||| [X][" << topLabel << "] [X][X] [" << topLabel << "] ||| 2.718 ||| 0-0 1-1 ||| ||| |||"; + if (options.treeFragments) { + out << " {{Tree [" << topLabel << " [" << topLabel << "] [X]]}}"; + } + if (options.sourceLabels) { + out << " {{SourceLabels 3 1 " << sourceTopLabel << " " << sourceSomeLabel << " 1 1 " << sourceTopLabel << " 1}}"; // TODO: there should be better options than using "SOMELABEL" + } + out << std::endl; +} + +void ExtractGHKM::WriteSourceLabelSet( + const std::map &sourceLabels, + std::ostream &out) +{ + out << sourceLabels.size() << std::endl; + for (std::map::const_iterator iter=sourceLabels.begin(); + iter!=sourceLabels.end(); ++iter) { + out << iter->first << " " << iter->second << std::endl; + } } void ExtractGHKM::CollectWordLabelCounts( @@ -513,11 +676,26 @@ void ExtractGHKM::CollectWordLabelCounts( } } +std::vector ExtractGHKM::ReadTokens(const ParseTree &root) const +{ + std::vector tokens; + std::vector leaves; + root.GetLeaves(std::back_inserter(leaves)); + for (std::vector::const_iterator p = leaves.begin(); + p != leaves.end(); ++p) { + const ParseTree &leaf = **p; + const std::string &word = leaf.GetLabel(); + tokens.push_back(word); + } + return tokens; +} + void ExtractGHKM::WriteUnknownWordLabel( const std::map &wordCount, const std::map &wordLabel, const Options &options, - std::ostream &out) + std::ostream &out, + bool writeCounts) { if (!options.unknownWordSoftMatchesFile.empty()) { out << "UNK 1" << std::endl; @@ -537,12 +715,19 @@ void ExtractGHKM::WriteUnknownWordLabel( ++total; } } - for (std::map::const_iterator p = labelCount.begin(); - p != labelCount.end(); ++p) { - double ratio = static_cast(p->second) / static_cast(total); - if (ratio >= options.unknownWordMinRelFreq) { 
- float weight = options.unknownWordUniform ? 1.0f : ratio; - out << p->first << " " << weight << std::endl; + if ( writeCounts ) { + for (std::map::const_iterator p = labelCount.begin(); + p != labelCount.end(); ++p) { + out << p->first << " " << p->second << std::endl; + } + } else { + for (std::map::const_iterator p = labelCount.begin(); + p != labelCount.end(); ++p) { + double ratio = static_cast(p->second) / static_cast(total); + if (ratio >= options.unknownWordMinRelFreq) { + float weight = options.unknownWordUniform ? 1.0f : ratio; + out << p->first << " " << weight << std::endl; + } } } } diff --git a/phrase-extract/extract-ghkm/ExtractGHKM.h b/phrase-extract/extract-ghkm/ExtractGHKM.h index 4c78923d3..44ce9fdbd 100644 --- a/phrase-extract/extract-ghkm/ExtractGHKM.h +++ b/phrase-extract/extract-ghkm/ExtractGHKM.h @@ -59,13 +59,19 @@ private: void WriteUnknownWordLabel(const std::map &, const std::map &, const Options &, - std::ostream &); + std::ostream &, + bool writeCounts=false); void WriteUnknownWordSoftMatches(const std::set &, std::ostream &); void WriteGlueGrammar(const std::set &, const std::map &, + const std::map &, + const Options &, std::ostream &); - std::vector ReadTokens(const std::string &); + void WriteSourceLabelSet(const std::map &, + std::ostream &); + std::vector ReadTokens(const std::string &) const; + std::vector ReadTokens(const ParseTree &root) const; void ProcessOptions(int, char *[], Options &) const; diff --git a/phrase-extract/extract-ghkm/Options.h b/phrase-extract/extract-ghkm/Options.h index ffa9bfa35..28a581802 100644 --- a/phrase-extract/extract-ghkm/Options.h +++ b/phrase-extract/extract-ghkm/Options.h @@ -41,6 +41,7 @@ public: , minimal(false) , pcfg(false) , treeFragments(false) + , sourceLabels(false) , sentenceOffset(0) , unpairedExtractFormat(false) , unknownWordMinRelFreq(0.03f) @@ -64,9 +65,12 @@ public: bool minimal; bool pcfg; bool treeFragments; + bool sourceLabels; + std::string sourceLabelSetFile; int 
sentenceOffset; bool unpairedExtractFormat; - std::string unknownWordFile; + std::string targetUnknownWordFile; + std::string sourceUnknownWordFile; std::string unknownWordSoftMatchesFile; float unknownWordMinRelFreq; bool unknownWordUniform; diff --git a/phrase-extract/extract-ghkm/ParseTree.h b/phrase-extract/extract-ghkm/ParseTree.h index 03da17735..694286c9d 100644 --- a/phrase-extract/extract-ghkm/ParseTree.h +++ b/phrase-extract/extract-ghkm/ParseTree.h @@ -63,7 +63,7 @@ public: bool IsLeaf() const; template - void GetLeaves(OutputIterator); + void GetLeaves(OutputIterator) const; private: // Disallow copying @@ -77,7 +77,7 @@ private: }; template -void ParseTree::GetLeaves(OutputIterator result) +void ParseTree::GetLeaves(OutputIterator result) const { if (IsLeaf()) { *result++ = this; diff --git a/phrase-extract/extract-ghkm/ScfgRule.cpp b/phrase-extract/extract-ghkm/ScfgRule.cpp index 2c901413d..a4dd91e0e 100644 --- a/phrase-extract/extract-ghkm/ScfgRule.cpp +++ b/phrase-extract/extract-ghkm/ScfgRule.cpp @@ -21,6 +21,7 @@ #include "Node.h" #include "Subgraph.h" +#include "SyntaxTree.h" #include @@ -29,11 +30,14 @@ namespace Moses namespace GHKM { -ScfgRule::ScfgRule(const Subgraph &fragment) +ScfgRule::ScfgRule(const Subgraph &fragment, + const MosesTraining::SyntaxTree *sourceSyntaxTree) : m_sourceLHS("X", NonTerminal) , m_targetLHS(fragment.GetRoot()->GetLabel(), NonTerminal) , m_pcfgScore(fragment.GetPcfgScore()) + , m_hasSourceLabels(sourceSyntaxTree) { + // Source RHS const std::set &leaves = fragment.GetLeaves(); @@ -55,6 +59,7 @@ ScfgRule::ScfgRule(const Subgraph &fragment) std::map > sourceOrder; m_sourceRHS.reserve(sourceRHSNodes.size()); + m_numberOfNonTerminals = 0; int srcIndex = 0; for (std::vector::const_iterator p(sourceRHSNodes.begin()); p != sourceRHSNodes.end(); ++p, ++srcIndex) { @@ -62,6 +67,11 @@ ScfgRule::ScfgRule(const Subgraph &fragment) if (sinkNode.GetType() == TREE) { m_sourceRHS.push_back(Symbol("X", NonTerminal)); 
sourceOrder[&sinkNode].push_back(srcIndex); + ++m_numberOfNonTerminals; + if (sourceSyntaxTree) { + // Source syntax label + PushSourceLabel(sourceSyntaxTree,&sinkNode,"XRHS"); + } } else { assert(sinkNode.GetType() == SOURCE); m_sourceRHS.push_back(Symbol(sinkNode.GetLabel(), Terminal)); @@ -112,6 +122,76 @@ ScfgRule::ScfgRule(const Subgraph &fragment) } } } + + if (sourceSyntaxTree) { + // Source syntax label for root node (if sourceSyntaxTree available) + PushSourceLabel(sourceSyntaxTree,fragment.GetRoot(),"XLHS"); + // All non-terminal spans (including the LHS) should have obtained a label + // (a source-side syntactic constituent label if the span matches, "XLHS" otherwise) + assert(m_sourceLabels.size() == m_numberOfNonTerminals+1); + } +} + +void ScfgRule::PushSourceLabel(const MosesTraining::SyntaxTree *sourceSyntaxTree, + const Node *node, + const std::string &nonMatchingLabel) +{ + ContiguousSpan span = Closure(node->GetSpan()); + if (sourceSyntaxTree->HasNode(span.first,span.second)) { // does a source constituent match the span? 
+ std::vector sourceLabels = + sourceSyntaxTree->GetNodes(span.first,span.second); + if (!sourceLabels.empty()) { + // store the topmost matching label from the source syntax tree + m_sourceLabels.push_back(sourceLabels.back()->GetLabel()); + } + } else { + // no matching source-side syntactic constituent: store nonMatchingLabel + m_sourceLabels.push_back(nonMatchingLabel); + } +} + +// TODO: rather implement the method external to ScfgRule +void ScfgRule::UpdateSourceLabelCoocCounts(std::map< std::string, std::map* > &coocCounts, float count) const +{ + std::map sourceToTargetNTMap; + std::map targetToSourceNTMap; + + for (Alignment::const_iterator p(m_alignment.begin()); + p != m_alignment.end(); ++p) { + if ( m_sourceRHS[p->first].GetType() == NonTerminal ) { + assert(m_targetRHS[p->second].GetType() == NonTerminal); + sourceToTargetNTMap[p->first] = p->second; + } + } + + size_t sourceIndex = 0; + size_t sourceNonTerminalIndex = 0; + for (std::vector::const_iterator p=m_sourceRHS.begin(); + p != m_sourceRHS.end(); ++p, ++sourceIndex) { + if ( p->GetType() == NonTerminal ) { + const std::string &sourceLabel = m_sourceLabels[sourceNonTerminalIndex]; + int targetIndex = sourceToTargetNTMap[sourceIndex]; + const std::string &targetLabel = m_targetRHS[targetIndex].GetValue(); + ++sourceNonTerminalIndex; + + std::map* countMap = NULL; + std::map< std::string, std::map* >::iterator iter = coocCounts.find(sourceLabel); + if ( iter == coocCounts.end() ) { + std::map *newCountMap = new std::map(); + std::pair< std::map< std::string, std::map* >::iterator, bool > inserted = + coocCounts.insert( std::pair< std::string, std::map* >(sourceLabel, newCountMap) ); + assert(inserted.second); + countMap = (inserted.first)->second; + } else { + countMap = iter->second; + } + std::pair< std::map::iterator, bool > inserted = + countMap->insert( std::pair< std::string,float>(targetLabel, count) ); + if ( !inserted.second ) { + (inserted.first)->second += count; + } + } + } } int 
ScfgRule::Scope() const diff --git a/phrase-extract/extract-ghkm/ScfgRule.h b/phrase-extract/extract-ghkm/ScfgRule.h index 21a9e9900..5f1f35a61 100644 --- a/phrase-extract/extract-ghkm/ScfgRule.h +++ b/phrase-extract/extract-ghkm/ScfgRule.h @@ -22,9 +22,13 @@ #define EXTRACT_GHKM_SCFG_RULE_H_ #include "Alignment.h" +#include "SyntaxTree.h" #include #include +#include +#include +#include namespace Moses { @@ -55,7 +59,8 @@ private: class ScfgRule { public: - ScfgRule(const Subgraph &fragment); + ScfgRule(const Subgraph &fragment, + const MosesTraining::SyntaxTree *sourceSyntaxTree = 0); const Symbol &GetSourceLHS() const { return m_sourceLHS; @@ -75,18 +80,36 @@ public: float GetPcfgScore() const { return m_pcfgScore; } + bool HasSourceLabels() const { + return m_hasSourceLabels; + } + void PrintSourceLabels(std::ostream &out) const { + for (std::vector::const_iterator it = m_sourceLabels.begin(); + it != m_sourceLabels.end(); ++it) { + out << " " << (*it); + } + } + void UpdateSourceLabelCoocCounts(std::map< std::string, std::map* > &coocCounts, + float count) const; int Scope() const; private: static bool PartitionOrderComp(const Node *, const Node *); + void PushSourceLabel(const MosesTraining::SyntaxTree *sourceSyntaxTree, + const Node *node, + const std::string &nonMatchingLabel); + Symbol m_sourceLHS; Symbol m_targetLHS; std::vector m_sourceRHS; std::vector m_targetRHS; Alignment m_alignment; float m_pcfgScore; + bool m_hasSourceLabels; + std::vector m_sourceLabels; + unsigned m_numberOfNonTerminals; }; } // namespace GHKM diff --git a/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp b/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp index bc8fd7233..be373b67b 100644 --- a/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp +++ b/phrase-extract/extract-ghkm/ScfgRuleWriter.cpp @@ -66,6 +66,12 @@ void ScfgRuleWriter::Write(const ScfgRule &rule, bool printEndl) m_fwd << " ||| " << std::exp(rule.GetPcfgScore()); } + if (m_options.sourceLabels && rule.HasSourceLabels()) { + 
m_fwd << " {{SourceLabels"; + rule.PrintSourceLabels(m_fwd); + m_fwd << "}}"; + } + if (printEndl) { m_fwd << std::endl; m_inv << std::endl; diff --git a/phrase-extract/extract-ghkm/XmlTreeParser.h b/phrase-extract/extract-ghkm/XmlTreeParser.h index d00fd7d9f..e5bf5b463 100644 --- a/phrase-extract/extract-ghkm/XmlTreeParser.h +++ b/phrase-extract/extract-ghkm/XmlTreeParser.h @@ -45,9 +45,11 @@ class XmlTreeParser public: XmlTreeParser(std::set &, std::map &); std::auto_ptr Parse(const std::string &); + + static std::auto_ptr ConvertTree(const MosesTraining::SyntaxNode &, + const std::vector &); + private: - std::auto_ptr ConvertTree(const MosesTraining::SyntaxNode &, - const std::vector &); std::set &m_labelSet; std::map &m_topLabelSet; diff --git a/phrase-extract/extract-main.cpp b/phrase-extract/extract-main.cpp index 5d58028d6..fe3d99cd2 100644 --- a/phrase-extract/extract-main.cpp +++ b/phrase-extract/extract-main.cpp @@ -19,7 +19,6 @@ #include #include -#include "SafeGetline.h" #include "SentenceAlignment.h" #include "tables-core.h" #include "InputFileStream.h" @@ -32,10 +31,6 @@ using namespace MosesTraining; namespace MosesTraining { - -const long int LINE_MAX_LENGTH = 500000 ; - - // HPhraseVertex represents a point in the alignment matrix typedef pair HPhraseVertex; @@ -277,20 +272,18 @@ int main(int argc, char* argv[]) int i = sentenceOffset; - while(true) { + string englishString, foreignString, alignmentString, weightString; + + while(getline(*eFileP, englishString)) { i++; if (i%10000 == 0) cerr << "." 
<< flush; - char englishString[LINE_MAX_LENGTH]; - char foreignString[LINE_MAX_LENGTH]; - char alignmentString[LINE_MAX_LENGTH]; - char weightString[LINE_MAX_LENGTH]; - SAFE_GETLINE((*eFileP), englishString, LINE_MAX_LENGTH, '\n', __FILE__); - if (eFileP->eof()) break; - SAFE_GETLINE((*fFileP), foreignString, LINE_MAX_LENGTH, '\n', __FILE__); - SAFE_GETLINE((*aFileP), alignmentString, LINE_MAX_LENGTH, '\n', __FILE__); + + getline(*fFileP, foreignString); + getline(*aFileP, alignmentString); if (iwFileP) { - SAFE_GETLINE((*iwFileP), weightString, LINE_MAX_LENGTH, '\n', __FILE__); + getline(*iwFileP, weightString); } + SentenceAlignment sentence; // cout << "read in: " << englishString << " & " << foreignString << " & " << alignmentString << endl; //az: output src, tgt, and alingment line @@ -300,7 +293,11 @@ int main(int argc, char* argv[]) cout << "LOG: ALT: " << alignmentString << endl; cout << "LOG: PHRASES_BEGIN:" << endl; } - if (sentence.create( englishString, foreignString, alignmentString, weightString, i, false)) { + if (sentence.create( englishString.c_str(), + foreignString.c_str(), + alignmentString.c_str(), + weightString.c_str(), + i, false)) { if (options.placeholders.size()) { sentence.invertAlignment(); } diff --git a/phrase-extract/extract-ordering-main.cpp b/phrase-extract/extract-ordering-main.cpp index 104457b01..b418ba24d 100644 --- a/phrase-extract/extract-ordering-main.cpp +++ b/phrase-extract/extract-ordering-main.cpp @@ -19,7 +19,6 @@ #include #include -#include "SafeGetline.h" #include "SentenceAlignment.h" #include "tables-core.h" #include "InputFileStream.h" @@ -32,10 +31,6 @@ using namespace MosesTraining; namespace MosesTraining { - -const long int LINE_MAX_LENGTH = 500000 ; - - // HPhraseVertex represents a point in the alignment matrix typedef pair HPhraseVertex; @@ -246,20 +241,20 @@ int main(int argc, char* argv[]) int i = sentenceOffset; - while(true) { + string englishString, foreignString, alignmentString, weightString; + + 
while(getline(*eFileP, englishString)) { i++; - if (i%10000 == 0) cerr << "." << flush; - char englishString[LINE_MAX_LENGTH]; - char foreignString[LINE_MAX_LENGTH]; - char alignmentString[LINE_MAX_LENGTH]; - char weightString[LINE_MAX_LENGTH]; - SAFE_GETLINE((*eFileP), englishString, LINE_MAX_LENGTH, '\n', __FILE__); - if (eFileP->eof()) break; - SAFE_GETLINE((*fFileP), foreignString, LINE_MAX_LENGTH, '\n', __FILE__); - SAFE_GETLINE((*aFileP), alignmentString, LINE_MAX_LENGTH, '\n', __FILE__); + + getline(*eFileP, englishString); + getline(*fFileP, foreignString); + getline(*aFileP, alignmentString); if (iwFileP) { - SAFE_GETLINE((*iwFileP), weightString, LINE_MAX_LENGTH, '\n', __FILE__); + getline(*iwFileP, weightString); } + + if (i%10000 == 0) cerr << "." << flush; + SentenceAlignment sentence; // cout << "read in: " << englishString << " & " << foreignString << " & " << alignmentString << endl; //az: output src, tgt, and alingment line @@ -269,7 +264,7 @@ int main(int argc, char* argv[]) cout << "LOG: ALT: " << alignmentString << endl; cout << "LOG: PHRASES_BEGIN:" << endl; } - if (sentence.create( englishString, foreignString, alignmentString, weightString, i, false)) { + if (sentence.create( englishString.c_str(), foreignString.c_str(), alignmentString.c_str(), weightString.c_str(), i, false)) { ExtractTask *task = new ExtractTask(i-1, sentence, options, extractFileOrientation); task->Run(); delete task; diff --git a/phrase-extract/extract-rules-main.cpp b/phrase-extract/extract-rules-main.cpp index f5f44316e..592946b0d 100644 --- a/phrase-extract/extract-rules-main.cpp +++ b/phrase-extract/extract-rules-main.cpp @@ -39,7 +39,6 @@ #include "Hole.h" #include "HoleCollection.h" #include "RuleExist.h" -#include "SafeGetline.h" #include "SentenceAlignmentWithSyntax.h" #include "SyntaxTree.h" #include "tables-core.h" @@ -47,8 +46,6 @@ #include "InputFileStream.h" #include "OutputFileStream.h" -#define LINE_MAX_LENGTH 500000 - using namespace std; using namespace 
MosesTraining; @@ -326,17 +323,15 @@ int main(int argc, char* argv[]) // loop through all sentence pairs size_t i=sentenceOffset; - while(true) { - i++; - if (i%1000 == 0) cerr << i << " " << flush; + string targetString, sourceString, alignmentString; - char targetString[LINE_MAX_LENGTH]; - char sourceString[LINE_MAX_LENGTH]; - char alignmentString[LINE_MAX_LENGTH]; - SAFE_GETLINE((*tFileP), targetString, LINE_MAX_LENGTH, '\n', __FILE__); - if (tFileP->eof()) break; - SAFE_GETLINE((*sFileP), sourceString, LINE_MAX_LENGTH, '\n', __FILE__); - SAFE_GETLINE((*aFileP), alignmentString, LINE_MAX_LENGTH, '\n', __FILE__); + while(getline(*tFileP, targetString)) { + i++; + + getline(*sFileP, sourceString); + getline(*aFileP, alignmentString); + + if (i%1000 == 0) cerr << i << " " << flush; SentenceAlignmentWithSyntax sentence (targetLabelCollection, sourceLabelCollection, @@ -349,7 +344,7 @@ int main(int argc, char* argv[]) cout << "LOG: PHRASES_BEGIN:" << endl; } - if (sentence.create(targetString, sourceString, alignmentString,"", i, options.boundaryRules)) { + if (sentence.create(targetString.c_str(), sourceString.c_str(), alignmentString.c_str(),"", i, options.boundaryRules)) { if (options.unknownWordLabelFlag) { collectWordLabelCounts(sentence); } diff --git a/phrase-extract/relax-parse-main.cpp b/phrase-extract/relax-parse-main.cpp index a58d4d97f..e5feb94d0 100644 --- a/phrase-extract/relax-parse-main.cpp +++ b/phrase-extract/relax-parse-main.cpp @@ -20,8 +20,6 @@ ***********************************************************************/ #include "relax-parse.h" - -#include "SafeGetline.h" #include "tables-core.h" using namespace std; @@ -33,17 +31,13 @@ int main(int argc, char* argv[]) // loop through all sentences int i=0; - char inBuffer[LINE_MAX_LENGTH]; - while(true) { + string inBuffer; + while(getline(cin, inBuffer)) { i++; if (i%1000 == 0) cerr << "." << flush; if (i%10000 == 0) cerr << ":" << flush; if (i%100000 == 0) cerr << "!" 
<< flush; - // get line from stdin - SAFE_GETLINE( cin, inBuffer, LINE_MAX_LENGTH, '\n', __FILE__); - if (cin.eof()) break; - // process into syntax tree representation string inBufferString = string( inBuffer ); set< string > labelCollection; // set of labels, not used diff --git a/phrase-extract/score-main.cpp b/phrase-extract/score-main.cpp index 46538010f..e8ba1d942 100644 --- a/phrase-extract/score-main.cpp +++ b/phrase-extract/score-main.cpp @@ -28,8 +28,8 @@ #include #include #include +#include -#include "SafeGetline.h" #include "ScoreFeature.h" #include "tables-core.h" #include "ExtractionPhrasePair.h" @@ -40,8 +40,6 @@ using namespace std; using namespace MosesTraining; -#define LINE_MAX_LENGTH 100000 - namespace MosesTraining { LexicalTable lexTable; @@ -49,6 +47,10 @@ bool inverseFlag = false; bool hierarchicalFlag = false; bool pcfgFlag = false; bool treeFragmentsFlag = false; +bool sourceSyntaxLabelsFlag = false; +bool sourceSyntaxLabelSetFlag = false; +bool sourceSyntaxLabelCountsLHSFlag = false; +bool targetPreferenceLabelsFlag = false; bool unpairedExtractFormatFlag = false; bool conditionOnTargetLhsFlag = false; bool wordAlignmentFlag = true; @@ -61,16 +63,25 @@ bool lexFlag = true; bool unalignedFlag = false; bool unalignedFWFlag = false; bool crossedNonTerm = false; +bool spanLength = false; +bool nonTermContext = false; + int countOfCounts[COC_MAX+1]; int totalDistinct = 0; float minCountHierarchical = 0; -std::map sourceLHSCounts; -std::map* > targetLHSAndSourceLHSJointCounts; +boost::unordered_map sourceLHSCounts; +boost::unordered_map* > targetLHSAndSourceLHSJointCounts; std::set sourceLabelSet; std::map sourceLabels; std::vector sourceLabelsByIndex; +boost::unordered_map targetPreferenceLHSCounts; +boost::unordered_map* > ruleTargetLHSAndTargetPreferenceLHSJointCounts; +std::set targetPreferenceLabelSet; +std::map targetPreferenceLabels; +std::vector targetPreferenceLabelsByIndex; + Vocabulary vcbT; Vocabulary vcbS; @@ -84,6 +95,11 @@ void 
processLine( std::string line, std::string &additionalPropertiesString, float &count, float &pcfgSum ); void writeCountOfCounts( const std::string &fileNameCountOfCounts ); +void writeLeftHandSideLabelCounts( const boost::unordered_map &countsLabelLHS, + const boost::unordered_map* > &jointCountsLabelLHS, + const std::string &fileNameLeftHandSideSourceLabelCounts, + const std::string &fileNameLeftHandSideTargetSourceLabelCounts ); +void writeLabelSet( const std::set &labelSet, const std::string &fileName ); void processPhrasePairs( std::vector< ExtractionPhrasePair* > &phrasePairsWithSameSource, ostream &phraseTableFile, const ScoreFeatureManager& featureManager, const MaybeLog& maybeLogProb ); void outputPhrasePair(const ExtractionPhrasePair &phrasePair, float, int, ostream &phraseTableFile, const ScoreFeatureManager &featureManager, const MaybeLog &maybeLog ); @@ -105,15 +121,21 @@ int main(int argc, char* argv[]) ScoreFeatureManager featureManager; if (argc < 4) { - std::cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring] [--KneserNey] [--NoWordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count] [--PCFG] [--TreeFragments] [--UnpairedExtractFormat] [--ConditionOnTargetLHS] [--CrossedNonTerm]" << std::endl; + std::cerr << "syntax: score extract lex phrase-table [--Inverse] [--Hierarchical] [--LogProb] [--NegLogProb] [--NoLex] [--GoodTuring] [--KneserNey] [--NoWordAlignment] [--UnalignedPenalty] [--UnalignedFunctionWordPenalty function-word-file] [--MinCountHierarchical count] [--PCFG] [--TreeFragments] [--SourceLabels] [--SourceLabelSet] [--SourceLabelCountsLHS] [--TargetPreferenceLabels] [--UnpairedExtractFormat] [--ConditionOnTargetLHS] [--CrossedNonTerm]" << std::endl; std::cerr << featureManager.usage() << std::endl; exit(1); } std::string fileNameExtract = argv[1]; std::string fileNameLex = argv[2]; std::string 
fileNamePhraseTable = argv[3]; + std::string fileNameSourceLabelSet; std::string fileNameCountOfCounts; std::string fileNameFunctionWords; + std::string fileNameLeftHandSideSourceLabelCounts; + std::string fileNameLeftHandSideTargetSourceLabelCounts; + std::string fileNameTargetPreferenceLabelSet; + std::string fileNameLeftHandSideTargetPreferenceLabelCounts; + std::string fileNameLeftHandSideRuleTargetTargetPreferenceLabelCounts; std::vector featureArgs; // all unknown args passed to feature manager for(int i=4; i phrasePairsWithSameSource; @@ -245,8 +293,8 @@ int main(int argc, char* argv[]) float tmpCount=0.0f, tmpPcfgSum=0.0f; int i=0; - SAFE_GETLINE( (extractFileP), line, LINE_MAX_LENGTH, '\n', __FILE__ ); - if ( !extractFileP.eof() ) { + // TODO why read only the 1st line? + if ( getline(extractFileP, line) ) { ++i; tmpPhraseSource = new PHRASE(); tmpPhraseTarget = new PHRASE(); @@ -265,23 +313,21 @@ int main(int argc, char* argv[]) if ( hierarchicalFlag ) { phrasePairsWithSameSourceAndTarget.push_back( phrasePair ); } - strcpy( lastLine, line ); - SAFE_GETLINE( (extractFileP), line, LINE_MAX_LENGTH, '\n', __FILE__ ); + lastLine = line; } - while ( !extractFileP.eof() ) { + while ( getline(extractFileP, line) ) { if ( ++i % 100000 == 0 ) { std::cerr << "." << std::flush; } // identical to last line? 
just add count - if (strcmp(line,lastLine) == 0) { + if (line == lastLine) { phrasePair->IncrementPrevious(tmpCount,tmpPcfgSum); - SAFE_GETLINE((extractFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); continue; } else { - strcpy( lastLine, line ); + lastLine = line; } tmpPhraseSource = new PHRASE(); @@ -359,8 +405,6 @@ int main(int argc, char* argv[]) } } - SAFE_GETLINE((extractFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - } processPhrasePairs( phrasePairsWithSameSource, *phraseTableFile, featureManager, maybeLogProb ); @@ -380,6 +424,26 @@ int main(int argc, char* argv[]) if (goodTuringFlag || kneserNeyFlag) { writeCountOfCounts( fileNameCountOfCounts ); } + + // source syntax labels + if (sourceSyntaxLabelsFlag && sourceSyntaxLabelSetFlag && !inverseFlag) { + writeLabelSet( sourceLabelSet, fileNameSourceLabelSet ); + } + if (sourceSyntaxLabelsFlag && sourceSyntaxLabelCountsLHSFlag && !inverseFlag) { + writeLeftHandSideLabelCounts( sourceLHSCounts, + targetLHSAndSourceLHSJointCounts, + fileNameLeftHandSideSourceLabelCounts, + fileNameLeftHandSideTargetSourceLabelCounts ); + } + + // target preference labels + if (targetPreferenceLabelsFlag && !inverseFlag) { + writeLabelSet( targetPreferenceLabelSet, fileNameTargetPreferenceLabelSet ); + writeLeftHandSideLabelCounts( targetPreferenceLHSCounts, + ruleTargetLHSAndTargetPreferenceLHSJointCounts, + fileNameLeftHandSideTargetPreferenceLabelCounts, + fileNameLeftHandSideRuleTargetTargetPreferenceLabelCounts ); + } } @@ -474,6 +538,70 @@ void writeCountOfCounts( const string &fileNameCountOfCounts ) } +void writeLeftHandSideLabelCounts( const boost::unordered_map &countsLabelLHS, + const boost::unordered_map* > &jointCountsLabelLHS, + const std::string &fileNameLeftHandSideSourceLabelCounts, + const std::string &fileNameLeftHandSideTargetSourceLabelCounts ) +{ + // open file + Moses::OutputFileStream leftHandSideSourceLabelCounts; + bool success = 
leftHandSideSourceLabelCounts.Open(fileNameLeftHandSideSourceLabelCounts.c_str()); + if (!success) { + std::cerr << "ERROR: could not open left-hand side label counts file " + << fileNameLeftHandSideSourceLabelCounts << std::endl; + return; + } + + // write source left-hand side counts + for (boost::unordered_map::const_iterator iter=sourceLHSCounts.begin(); + iter!=sourceLHSCounts.end(); ++iter) { + leftHandSideSourceLabelCounts << iter->first << " " << iter->second << std::endl; + } + + leftHandSideSourceLabelCounts.Close(); + + // open file + Moses::OutputFileStream leftHandSideTargetSourceLabelCounts; + success = leftHandSideTargetSourceLabelCounts.Open(fileNameLeftHandSideTargetSourceLabelCounts.c_str()); + if (!success) { + std::cerr << "ERROR: could not open left-hand side label joint counts file " + << fileNameLeftHandSideTargetSourceLabelCounts << std::endl; + return; + } + + // write source left-hand side / target left-hand side joint counts + for (boost::unordered_map* >::const_iterator iter=targetLHSAndSourceLHSJointCounts.begin(); + iter!=targetLHSAndSourceLHSJointCounts.end(); ++iter) { + for (boost::unordered_map::const_iterator iter2=(iter->second)->begin(); + iter2!=(iter->second)->end(); ++iter2) { + leftHandSideTargetSourceLabelCounts << iter->first << " "<< iter2->first << " " << iter2->second << std::endl; + } + } + + leftHandSideTargetSourceLabelCounts.Close(); +} + + +void writeLabelSet( const std::set &labelSet, const std::string &fileName ) +{ + // open file + Moses::OutputFileStream out; + bool success = out.Open(fileName.c_str()); + if (!success) { + std::cerr << "ERROR: could not open label set file " + << fileName << std::endl; + return; + } + + for (std::set::const_iterator iter=labelSet.begin(); + iter!=labelSet.end(); ++iter) { + out << *iter << std::endl; + } + + out.Close(); +} + + void processPhrasePairs( std::vector< ExtractionPhrasePair* > &phrasePairsWithSameSource, ostream &phraseTableFile, const ScoreFeatureManager& 
featureManager, const MaybeLog& maybeLogProb ) { @@ -646,7 +774,7 @@ void outputPhrasePair(const ExtractionPhrasePair &phrasePair, if (kneserNeyFlag) phraseTableFile << " " << distinctCount; - if ((treeFragmentsFlag) && + if ((treeFragmentsFlag || sourceSyntaxLabelsFlag || targetPreferenceLabelsFlag) && !inverseFlag) { phraseTableFile << " |||"; } @@ -661,6 +789,63 @@ void outputPhrasePair(const ExtractionPhrasePair &phrasePair, } } + // syntax labels + if ((sourceSyntaxLabelsFlag || targetPreferenceLabelsFlag) && !inverseFlag) { + unsigned nNTs = 1; + for(size_t j=0; jsize()-1; ++j) { + if (isNonTerminal(vcbS.getWord( phraseSource->at(j) ))) + ++nNTs; + } + // source syntax labels + if (sourceSyntaxLabelsFlag) { + std::string sourceLabelCounts; + sourceLabelCounts = phrasePair.CollectAllLabelsSeparateLHSAndRHS("SourceLabels", + sourceLabelSet, + sourceLHSCounts, + targetLHSAndSourceLHSJointCounts, + vcbT); + if ( !sourceLabelCounts.empty() ) { + phraseTableFile << " {{SourceLabels " + << nNTs // for convenience: number of non-terminal symbols in this rule (incl. left hand side NT) + << " " + << count // rule count + << sourceLabelCounts + << "}}"; + } + } + // target preference labels + if (targetPreferenceLabelsFlag) { + std::string targetPreferenceLabelCounts; + targetPreferenceLabelCounts = phrasePair.CollectAllLabelsSeparateLHSAndRHS("TargetPreferences", + targetPreferenceLabelSet, + targetPreferenceLHSCounts, + ruleTargetLHSAndTargetPreferenceLHSJointCounts, + vcbT); + if ( !targetPreferenceLabelCounts.empty() ) { + phraseTableFile << " {{TargetPreferences " + << nNTs // for convenience: number of non-terminal symbols in this rule (incl. 
left hand side NT) + << " " + << count // rule count + << targetPreferenceLabelCounts + << "}}"; + } + } + } + + if (spanLength && !inverseFlag) { + string propValue = phrasePair.CollectAllPropertyValues("SpanLength"); + if (!propValue.empty()) { + phraseTableFile << " {{SpanLength " << propValue << "}}"; + } + } + + if (nonTermContext && !inverseFlag) { + string propValue = phrasePair.CollectAllPropertyValues("NonTermContext"); + if (!propValue.empty()) { + phraseTableFile << " {{NonTermContext " << propValue << "}}"; + } + } + phraseTableFile << std::endl; } @@ -750,11 +935,9 @@ void loadFunctionWords( const string &fileName ) } istream *inFileP = &inFile; - char line[LINE_MAX_LENGTH]; - while(true) { - SAFE_GETLINE((*inFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - if (inFileP->eof()) break; - std::vector token = tokenize( line ); + string line; + while(getline(*inFileP, line)) { + std::vector token = tokenize( line.c_str() ); if (token.size() > 0) functionWordList.insert( token[0] ); } @@ -799,16 +982,13 @@ void LexicalTable::load( const string &fileName ) } istream *inFileP = &inFile; - char line[LINE_MAX_LENGTH]; - + string line; int i=0; - while(true) { + while(getline(*inFileP, line)) { i++; if (i%100000 == 0) std::cerr << "." << flush; - SAFE_GETLINE((*inFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - if (inFileP->eof()) break; - std::vector token = tokenize( line ); + std::vector token = tokenize( line.c_str() ); if (token.size() != 3) { std::cerr << "line " << i << " in " << fileName << " has wrong number of tokens, skipping:" << std::endl @@ -906,3 +1086,4 @@ void invertAlignment(const PHRASE *phraseSource, const PHRASE *phraseTarget, } } } + diff --git a/phrase-extract/score.h b/phrase-extract/score.h index 6a10536c1..470332a06 100644 --- a/phrase-extract/score.h +++ b/phrase-extract/score.h @@ -1,12 +1,22 @@ -#pragma once -/* - * score.h - * extract - * - * Created by Hieu Hoang on 28/07/2010. - * Copyright 2010 __MyCompanyName__. 
All rights reserved. - * - */ +/*********************************************************************** + Moses - factored phrase-based language decoder + Copyright (C) 2009 University of Edinburgh + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + ***********************************************************************/ + #include #include diff --git a/phrase-extract/statistics-main.cpp b/phrase-extract/statistics-main.cpp index 67373ec93..9d814ed76 100644 --- a/phrase-extract/statistics-main.cpp +++ b/phrase-extract/statistics-main.cpp @@ -12,15 +12,12 @@ #include #include "AlignmentPhrase.h" -#include "SafeGetline.h" #include "tables-core.h" #include "InputFileStream.h" using namespace std; using namespace MosesTraining; -#define LINE_MAX_LENGTH 10000 - namespace MosesTraining { @@ -31,7 +28,7 @@ public: vector< vector > alignedToE; vector< vector > alignedToF; - bool create( char*, int ); + bool create( const char*, int ); void clear(); bool equals( const PhraseAlignment& ); }; @@ -106,16 +103,14 @@ int main(int argc, char* argv[]) vector< PhraseAlignment > phrasePairsWithSameF; int i=0; int fileCount = 0; - while(true) { + + string line; + while(getline(extractFileP, line)) { if (extractFileP.eof()) break; if (++i % 100000 == 0) cerr << "." 
<< flush; - char line[LINE_MAX_LENGTH]; - SAFE_GETLINE((extractFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - // if (fileCount>0) - if (extractFileP.eof()) - break; + PhraseAlignment phrasePair; - bool isPhrasePair = phrasePair.create( line, i ); + bool isPhrasePair = phrasePair.create( line.c_str(), i ); if (lastForeign >= 0 && lastForeign != phrasePair.foreign) { processPhrasePairs( phrasePairsWithSameF ); for(size_t j=0; j &phrasePair ) } } -bool PhraseAlignment::create( char line[], int lineID ) +bool PhraseAlignment::create(const char line[], int lineID ) { vector< string > token = tokenize( line ); int item = 1; @@ -321,16 +316,14 @@ void LexicalTable::load( const string &filePath ) } istream *inFileP = &inFile; - char line[LINE_MAX_LENGTH]; + string line; int i=0; - while(true) { + while(getline(*inFileP, line)) { i++; if (i%100000 == 0) cerr << "." << flush; - SAFE_GETLINE((*inFileP), line, LINE_MAX_LENGTH, '\n', __FILE__); - if (inFileP->eof()) break; - vector token = tokenize( line ); + vector token = tokenize( line.c_str() ); if (token.size() != 3) { cerr << "line " << i << " in " << filePath << " has wrong number of tokens, skipping:\n" << token.size() << " " << token[0] << " " << line << endl; diff --git a/phrase-extract/tables-core.h b/phrase-extract/tables-core.h index e239e5900..9662ced2a 100644 --- a/phrase-extract/tables-core.h +++ b/phrase-extract/tables-core.h @@ -27,7 +27,7 @@ public: std::vector< WORD > vocab; WORD_ID storeIfNew( const WORD& ); WORD_ID getWordID( const WORD& ); - inline WORD &getWord( WORD_ID id ) { + inline WORD &getWord( const WORD_ID id ) { return vocab[ id ]; } }; diff --git a/scripts/OSM/OSM-Train.perl b/scripts/OSM/OSM-Train.perl index 44c43796d..2d2427bc5 100755 --- a/scripts/OSM/OSM-Train.perl +++ b/scripts/OSM/OSM-Train.perl @@ -9,7 +9,7 @@ print STDERR "Training OSM - Start\n".`date`; my $ORDER = 5; my $OUT_DIR = "/tmp/osm.$$"; my $___FACTOR_DELIMITER = "|"; -my 
($MOSES_SRC_DIR,$CORPUS_F,$CORPUS_E,$ALIGNMENT,$SRILM_DIR,$FACTOR); +my ($MOSES_SRC_DIR,$CORPUS_F,$CORPUS_E,$ALIGNMENT,$SRILM_DIR,$FACTOR,$LMPLZ); # utilities my $ZCAT = "gzip -cd"; @@ -23,15 +23,16 @@ die("ERROR: wrong syntax when invoking OSM-Train.perl") 'order=i' => \$ORDER, 'factor=s' => \$FACTOR, 'srilm-dir=s' => \$SRILM_DIR, + 'lmplz=s' => \$LMPLZ, 'out-dir=s' => \$OUT_DIR); # check if the files are in place -die("ERROR: you need to define --corpus-e, --corpus-f, --alignment, --srilm-dir, and --moses-src-dir") +die("ERROR: you need to define --corpus-e, --corpus-f, --alignment, --srilm-dir or --lmplz, and --moses-src-dir") unless (defined($MOSES_SRC_DIR) && defined($CORPUS_F) && defined($CORPUS_E) && defined($ALIGNMENT)&& - defined($SRILM_DIR)); + (defined($SRILM_DIR) || defined($LMPLZ))); die("ERROR: could not find input corpus file '$CORPUS_F'") unless -e $CORPUS_F; die("ERROR: could not find output corpus file '$CORPUS_E'") @@ -87,7 +88,12 @@ print "Converting Bilingual Sentence Pair into Operation Corpus\n"; `$MOSES_SRC_DIR/bin/generateSequences $OUT_DIR/$factor_val/e $OUT_DIR/$factor_val/f $OUT_DIR/align $OUT_DIR/$factor_val/Singletons > $OUT_DIR/$factor_val/opCorpus`; print "Learning Operation Sequence Translation Model\n"; -`$SRILM_DIR/ngram-count -kndiscount -order $ORDER -unk -text $OUT_DIR/$factor_val/opCorpus -lm $OUT_DIR/$factor_val/operationLM`; +if (defined($LMPLZ)) { + `$LMPLZ --order $ORDER --text $OUT_DIR/$factor_val/opCorpus --arpa $OUT_DIR/$factor_val/operationLM --prune 0 0 1`; +} +else { + `$SRILM_DIR/ngram-count -kndiscount -order $ORDER -unk -text $OUT_DIR/$factor_val/opCorpus -lm $OUT_DIR/$factor_val/operationLM`; +} print "Binarizing\n"; `$MOSES_SRC_DIR/bin/build_binary $OUT_DIR/$factor_val/operationLM $OUT_DIR/$factor_val/operationLM.bin`; diff --git a/scripts/ems/example/config.basic b/scripts/ems/example/config.basic index 07bb4f8c4..1db8154f5 100644 --- a/scripts/ems/example/config.basic +++ b/scripts/ems/example/config.basic @@ 
-382,7 +382,7 @@ alignment-symmetrization-method = grow-diag-final-and # #operation-sequence-model = "yes" #operation-sequence-model-order = 5 -#operation-sequence-model-settings = "" +#operation-sequence-model-settings = "-lmplz '$moses-src-dir/bin/lmplz -S 40% -T $working-dir/model/tmp'" ### if OSM training should be skipped, # point to OSM Model diff --git a/scripts/ems/example/config.factored b/scripts/ems/example/config.factored index bbb4bf793..c3a6b2a85 100644 --- a/scripts/ems/example/config.factored +++ b/scripts/ems/example/config.factored @@ -402,7 +402,7 @@ alignment-symmetrization-method = grow-diag-final-and # #operation-sequence-model = "yes" #operation-sequence-model-order = 5 -#operation-sequence-model-settings = "" +#operation-sequence-model-settings = "-lmplz '$moses-src-dir/bin/lmplz -S 40% -T $working-dir/model/tmp'" ### if OSM training should be skipped, # point to OSM Model diff --git a/scripts/ems/example/config.hierarchical b/scripts/ems/example/config.hierarchical index 8be19e2d7..673ad64a9 100644 --- a/scripts/ems/example/config.hierarchical +++ b/scripts/ems/example/config.hierarchical @@ -382,7 +382,7 @@ alignment-symmetrization-method = grow-diag-final-and # #operation-sequence-model = "yes" #operation-sequence-model-order = 5 -#operation-sequence-model-settings = "" +#operation-sequence-model-settings = "-lmplz '$moses-src-dir/bin/lmplz -S 40% -T $working-dir/model/tmp'" ### if OSM training should be skipped, # point to OSM Model diff --git a/scripts/ems/example/config.syntax b/scripts/ems/example/config.syntax index e9e9d5e47..7df60f990 100644 --- a/scripts/ems/example/config.syntax +++ b/scripts/ems/example/config.syntax @@ -386,7 +386,7 @@ alignment-symmetrization-method = grow-diag-final-and # #operation-sequence-model = "yes" #operation-sequence-model-order = 5 -#operation-sequence-model-settings = "" +#operation-sequence-model-settings = "-lmplz '$moses-src-dir/bin/lmplz -S 40% -T $working-dir/model/tmp'" ### if OSM training 
should be skipped, # point to OSM Model diff --git a/scripts/ems/experiment.meta b/scripts/ems/experiment.meta index d9715a8c9..9785d8940 100644 --- a/scripts/ems/experiment.meta +++ b/scripts/ems/experiment.meta @@ -150,8 +150,14 @@ tokenize pass-unless: output-tokenizer template: $output-tokenizer < IN > OUT parallelizable: yes -factorize +mock-parse in: tokenized-corpus + out: mock-parsed-corpus + default-name: lm/mock-parsed + pass-unless: mock-output-parser-lm + template: $mock-output-parser-lm < IN > OUT +factorize + in: mock-parsed-corpus out: factorized-corpus rerun-on-change: TRAINING:output-factors default-name: lm/factored @@ -234,8 +240,14 @@ tokenize-tuning pass-unless: output-tokenizer template: $output-tokenizer < IN > OUT parallelizable: yes -factorize-tuning +mock-parse-tuning in: tokenized-tuning + out: mock-parsed-tuning + default-name: lm/interpolate-tuning.mock-parsed + pass-unless: mock-output-parser-lm + template: $mock-output-parser-lm < IN > OUT +factorize-tuning + in: mock-parsed-tuning out: factorized-tuning default-name: lm/interpolate-tuning.factored pass-unless: TRAINING:output-factors @@ -705,17 +717,32 @@ tokenize-input-devtest pass-unless: input-tokenizer ignore-unless: use-mira template: $input-tokenizer < IN > OUT -parse-input +mock-parse-input in: tokenized-input + out: mock-parsed-input + default-name: tuning/input.mock-parsed + pass-unless: mock-input-parser-devtesteval + template: $mock-input-parser-devtesteval < IN > OUT +mock-parse-input-devtest + in: tokenized-input-devtest + out: mock-parsed-input-devtest + default-name: tuning/input.devtest.mock-parsed + pass-unless: mock-input-parser-devtesteval + ignore-unless: use-mira + template: $mock-input-parser-devtesteval < IN > OUT +parse-input + in: mock-parsed-input out: parsed-input default-name: tuning/input.parsed pass-unless: input-parser + pass-if: skip-parse-input-devtesteval mock-input-parser-devtesteval template: $input-parser < IN > OUT parse-input-devtest - in: 
tokenized-input-devtest + in: mock-parsed-input-devtesteval out: parsed-input-devtest default-name: tuning/input.devtest.parsed pass-unless: input-parser + pass-if: skip-parse-input-devtesteval mock-input-parser-devtesteval ignore-unless: use-mira template: $input-parser < IN > OUT parse-relax-input @@ -723,14 +750,16 @@ parse-relax-input out: parse-relaxed-input default-name: tuning/input.parse-relaxed pass-unless: input-parse-relaxer - template: $input-parse-relaxer < IN.$input-extension > OUT.$input-extension + pass-if: skip-parse-input-devtesteval mock-input-parser-devtesteval + template: $input-parse-relaxer < IN.$input-extension > OUT.$input-extension parse-relax-input-devtest in: parsed-input-devtest out: parse-relaxed-input-devtest default-name: tuning/input.devtest.parse-relaxed pass-unless: input-parse-relaxer + pass-if: skip-parse-input-devtesteval mock-input-parser-devtesteval ignore-unless: use-mira - template: $input-parse-relaxer < IN.$input-extension > OUT.$input-extension + template: $input-parse-relaxer < IN.$input-extension > OUT.$input-extension factorize-input in: parse-relaxed-input out: factorized-input @@ -832,8 +861,20 @@ tokenize-reference-devtest ignore-unless: use-mira multiref: $moses-script-dir/ems/support/run-command-on-multiple-refsets.perl template: $output-tokenizer < IN > OUT -lowercase-reference +mock-parse-reference in: tokenized-reference + out: mock-parsed-reference + default-name: tuning/reference.mock-parsed + pass-unless: mock-output-parser-references + template: $mock-output-parser-references < IN > OUT +mock-parse-reference-devtest + in: tokenized-input-devtest + out: mock-parsed-reference-devtest + default-name: tuning/reference.devtest.mock-parsed + pass-unless: mock-output-parser-references + template: $mock-output-parser-references < IN > OUT +lowercase-reference + in: mock-parsed-reference out: truecased-reference default-name: tuning/reference.lc pass-unless: output-lowercaser @@ -841,7 +882,7 @@ lowercase-reference 
multiref: $moses-script-dir/ems/support/run-command-on-multiple-refsets.perl template: $output-lowercaser < IN > OUT lowercase-reference-devtest - in: tokenized-reference-devtest + in: mock-parsed-reference-devtest out: truecased-reference-devtest default-name: tuning/reference.devtest.lc pass-unless: output-lowercaser @@ -850,7 +891,7 @@ lowercase-reference-devtest multiref: $moses-script-dir/ems/support/run-command-on-multiple-refsets.perl template: $output-lowercaser < IN > OUT truecase-reference - in: tokenized-reference TRUECASER:truecase-model + in: mock-parsed-reference TRUECASER:truecase-model out: truecased-reference rerun-on-change: output-truecaser default-name: tuning/reference.tc @@ -858,7 +899,7 @@ truecase-reference multiref: $moses-script-dir/ems/support/run-command-on-multiple-refsets.perl template: $output-truecaser -model IN1.$output-extension < IN > OUT truecase-reference-devtest - in: tokenized-reference-devtest TRUECASER:truecase-model + in: mock-parsed-reference-devtest TRUECASER:truecase-model out: truecased-reference-devtest rerun-on-change: output-truecaser default-name: tuning/reference.devtest.tc @@ -959,18 +1000,26 @@ tokenize-input default-name: evaluation/input.tok pass-unless: input-tokenizer template: $input-tokenizer < IN > OUT -parse-input +mock-parse-input in: tokenized-input + out: mock-parsed-input + default-name: evaluation/input.mock-parsed + pass-unless: mock-input-parser-devtesteval + template: $mock-input-parser-devtesteval < IN > OUT +parse-input + in: mock-parsed-input out: parsed-input default-name: evaluation/input.parsed pass-unless: input-parser + pass-if: skip-parse-input-devtesteval mock-input-parser-devtesteval template: $input-parser < IN > OUT parse-relax-input in: parsed-input out: parse-relaxed-input default-name: tuning/input.parse-relaxed pass-unless: input-parse-relaxer - template: $input-parse-relaxer < IN.$input-extension > OUT.$input-extension + pass-if: skip-parse-input-devtesteval 
mock-input-parser-devtesteval + template: $input-parse-relaxer < IN.$input-extension > OUT.$input-extension factorize-input in: parse-relaxed-input out: factorized-input @@ -1093,8 +1142,14 @@ tokenize-reference pass-unless: output-tokenizer multiref: $moses-script-dir/ems/support/run-command-on-multiple-refsets.perl template: $output-tokenizer < IN > OUT -lowercase-reference +mock-parse-reference in: tokenized-reference + out: mock-parsed-reference + default-name: evaluation/reference.mock-parsed + pass-unless: mock-output-parser-references + template: $mock-output-parser-references < IN > OUT +lowercase-reference + in: mock-parsed-reference out: reference default-name: evaluation/reference pass-unless: output-lowercaser diff --git a/scripts/ems/experiment.perl b/scripts/ems/experiment.perl index 3f4e53f23..4f67a6d8a 100755 --- a/scripts/ems/experiment.perl +++ b/scripts/ems/experiment.perl @@ -101,7 +101,7 @@ $VERSION = $DELETE_VERSION if $DELETE_VERSION; `mkdir -p steps/$VERSION` unless -d "steps/$VERSION"; &log_config() unless $DELETE_CRASHED || $DELETE_VERSION; -print "running experimenal run number $VERSION\n"; +print "running experimental run number $VERSION\n"; print "\nESTABLISH WHICH STEPS NEED TO BE RUN\n"; my (%NEEDED, # mapping of input files to step numbers @@ -2406,24 +2406,16 @@ sub define_training_create_config { $cmd .= "-transliteration-phrase-table $transliteration_pt "; } - if($osm){ - + if ($osm) { my $osm_settings = &get("TRAINING:operation-sequence-model-settings"); - - - if($osm_settings =~ /factor/){ - - $cmd .= "-osm-model $osm/ "; - my $find = "--factor"; - my $replace = "-osm-setting"; - $osm_settings =~ s/$find/$replace/g; - $cmd .= "$osm_settings "; - } - else{ - $cmd .= "-osm-model $osm/operationLM.bin "; - } + if ($osm_settings =~ /-factor *(\S+)/){ + $cmd .= "-osm-model $osm/ -osm-setting $1 "; + } + else { + $cmd .= "-osm-model $osm/operationLM.bin "; + } } - + # sparse lexical features provide additional content for config file 
$cmd .= "-additional-ini-file $sparse_lexical_features.ini " if $sparse_lexical_features; diff --git a/scripts/ems/support/defaultconfig.py b/scripts/ems/support/defaultconfig.py index 5d5187c47..e88b63e3d 100644 --- a/scripts/ems/support/defaultconfig.py +++ b/scripts/ems/support/defaultconfig.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python2 # # Version of ConfigParser which accepts default values diff --git a/scripts/ems/support/mml-filter.py b/scripts/ems/support/mml-filter.py index 437c9dade..5fb43d71e 100755 --- a/scripts/ems/support/mml-filter.py +++ b/scripts/ems/support/mml-filter.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python2 # # Filter a parallel corpus diff --git a/scripts/ems/support/wrap-xml.perl b/scripts/ems/support/wrap-xml.perl index 4ef6a1de6..beeca6cdd 100755 --- a/scripts/ems/support/wrap-xml.perl +++ b/scripts/ems/support/wrap-xml.perl @@ -10,6 +10,7 @@ open(SRC,$src) or die "Cannot open: $!"; my @OUT = ; chomp(@OUT); #my @OUT = `cat $decoder_output`; +my $missing_end_seg = 0; while() { chomp; if (/^) { $line = "" if $line =~ /NO BEST TRANSLATION/; if (/<\/seg>/) { s/(]+> *).*(<\/seg>)/$1$line$2/i; + $missing_end_seg = 0; } else { - s/(]+> *)[^<]*/$1$line/i; + s/(]+> *)[^<]*/$1$line<\/seg>/i; + $missing_end_seg = 1; } } + elsif ($missing_end_seg) { + if (/<\/doc>/) { + $missing_end_seg = 0; + } + else { + next; + } + } print $_."\n"; } diff --git a/scripts/generic/extract-parallel.perl b/scripts/generic/extract-parallel.perl index b663dcfe8..7abada1de 100755 --- a/scripts/generic/extract-parallel.perl +++ b/scripts/generic/extract-parallel.perl @@ -64,20 +64,20 @@ my $pid; if ($numParallel > 1) { - $cmd = "$splitCmd -d -l $linesPerSplit -a 5 $target $TMPDIR/target."; + $cmd = "$splitCmd -d -l $linesPerSplit -a 7 $target $TMPDIR/target."; $pid = RunFork($cmd); push(@children, $pid); - $cmd = "$splitCmd -d -l $linesPerSplit -a 5 $source $TMPDIR/source."; + $cmd = "$splitCmd -d -l $linesPerSplit -a 7 $source 
$TMPDIR/source."; $pid = RunFork($cmd); push(@children, $pid); - $cmd = "$splitCmd -d -l $linesPerSplit -a 5 $align $TMPDIR/align."; + $cmd = "$splitCmd -d -l $linesPerSplit -a 7 $align $TMPDIR/align."; $pid = RunFork($cmd); push(@children, $pid); if ($weights) { - $cmd = "$splitCmd -d -l $linesPerSplit -a 5 $weights $TMPDIR/weights."; + $cmd = "$splitCmd -d -l $linesPerSplit -a 7 $weights $TMPDIR/weights."; $pid = RunFork($cmd); push(@children, $pid); } @@ -259,15 +259,21 @@ sub NumStr($) my $i = shift; my $numStr; if ($i < 10) { - $numStr = "0000$i"; + $numStr = "000000$i"; } elsif ($i < 100) { - $numStr = "000$i"; + $numStr = "00000$i"; } elsif ($i < 1000) { - $numStr = "00$i"; + $numStr = "0000$i"; } elsif ($i < 10000) { + $numStr = "000$i"; + } + elsif ($i < 100000) { + $numStr = "00$i"; + } + elsif ($i < 1000000) { $numStr = "0$i"; } else { diff --git a/scripts/generic/generic-parallel.perl b/scripts/generic/generic-parallel.perl index 594fbcf5d..2becba31c 100755 --- a/scripts/generic/generic-parallel.perl +++ b/scripts/generic/generic-parallel.perl @@ -90,19 +90,25 @@ sub NumStr($) my $i = shift; my $numStr; if ($i < 10) { - $numStr = "0000$i"; + $numStr = "000000$i"; } elsif ($i < 100) { - $numStr = "000$i"; + $numStr = "00000$i"; } elsif ($i < 1000) { - $numStr = "00$i"; + $numStr = "0000$i"; } elsif ($i < 10000) { - $numStr = "0$i"; + $numStr = "000$i"; + } + elsif ($i < 100000) { + $numStr = "00$i"; + } + elsif ($i < 1000000) { + $numStr = "0$i"; } else { - $numStr = $i; + $numStr = $i; } return $numStr; } diff --git a/scripts/generic/giza-parallel.perl b/scripts/generic/giza-parallel.perl index 60059b46a..55192af74 100755 --- a/scripts/generic/giza-parallel.perl +++ b/scripts/generic/giza-parallel.perl @@ -102,23 +102,29 @@ print $cmd; sub NumStr($) { - my $i = shift; - my $numStr; - if ($i < 10) { - $numStr = "0000$i"; - } - elsif ($i < 100) { - $numStr = "000$i"; - } - elsif ($i < 1000) { - $numStr = "00$i"; - } - elsif ($i < 10000) { - $numStr = 
"0$i"; - } - else { - $numStr = $i; - } - return $numStr; + my $i = shift; + my $numStr; + if ($i < 10) { + $numStr = "000000$i"; + } + elsif ($i < 100) { + $numStr = "00000$i"; + } + elsif ($i < 1000) { + $numStr = "0000$i"; + } + elsif ($i < 10000) { + $numStr = "000$i"; + } + elsif ($i < 100000) { + $numStr = "00$i"; + } + elsif ($i < 1000000) { + $numStr = "0$i"; + } + else { + $numStr = $i; + } + return $numStr; } diff --git a/scripts/generic/score-parallel.perl b/scripts/generic/score-parallel.perl index da37b1353..213b9e90e 100755 --- a/scripts/generic/score-parallel.perl +++ b/scripts/generic/score-parallel.perl @@ -305,15 +305,21 @@ sub NumStr($) my $i = shift; my $numStr; if ($i < 10) { - $numStr = "0000$i"; + $numStr = "000000$i"; } elsif ($i < 100) { - $numStr = "000$i"; + $numStr = "00000$i"; } elsif ($i < 1000) { - $numStr = "00$i"; + $numStr = "0000$i"; } elsif ($i < 10000) { + $numStr = "000$i"; + } + elsif ($i < 100000) { + $numStr = "00$i"; + } + elsif ($i < 1000000) { $numStr = "0$i"; } else { diff --git a/scripts/generic/strip-xml.perl b/scripts/generic/strip-xml.perl index 9fc43d4d9..40a61302a 100755 --- a/scripts/generic/strip-xml.perl +++ b/scripts/generic/strip-xml.perl @@ -9,13 +9,14 @@ while (my $line = ) { my $len = length($line); my $inXML = 0; my $prevSpace = 1; + my $prevBar = 0; for (my $i = 0; $i < $len; ++$i) { my $c = substr($line, $i, 1); - if ($c eq "<") { + if ($c eq "<" && !$prevBar) { ++$inXML; } - elsif ($c eq ">") { + elsif ($c eq ">" && $inXML>0) { --$inXML; } elsif ($prevSpace == 1 && $c eq " ") @@ -24,9 +25,15 @@ while (my $line = ) { elsif ($inXML == 0) { if ($c eq " ") { $prevSpace = 1; + $prevBar = 0; + } + elsif ($c eq "|") { + $prevSpace = 0; + $prevBar = 1; } else { $prevSpace = 0; + $prevBar = 0; } print $c; } diff --git a/scripts/other/delete-scores.perl b/scripts/other/delete-scores.perl index 442173026..2a4f51c89 100755 --- a/scripts/other/delete-scores.perl +++ b/scripts/other/delete-scores.perl @@ -59,3 +59,5 
@@ sub DeleteScore return $string; } + + diff --git a/scripts/other/retain-lines.perl b/scripts/other/retain-lines.perl new file mode 100755 index 000000000..6f7c517c2 --- /dev/null +++ b/scripts/other/retain-lines.perl @@ -0,0 +1,31 @@ +#!/usr/bin/perl + +#retain lines in clean.lines-retained.1 +use strict; +use warnings; + +binmode(STDIN, ":utf8"); +binmode(STDOUT, ":utf8"); +binmode(STDERR, ":utf8"); + +my $retainPath = $ARGV[0]; + +open(LINE_RETAINED, $retainPath); +my $retainLine = ; + +my $lineNum = 0; +while (my $line = ) { + chomp($line); + ++$lineNum; + + if ($retainLine == $lineNum) { + print "$line\n"; + if ($retainLine = ) { + # do nothing + } + else { + # retained lines is finished. + $retainLine = 0; + } + } +} diff --git a/scripts/recaser/train-recaser.perl b/scripts/recaser/train-recaser.perl index f36e35dcc..fa833dbd6 100755 --- a/scripts/recaser/train-recaser.perl +++ b/scripts/recaser/train-recaser.perl @@ -10,8 +10,9 @@ binmode(STDOUT, ":utf8"); # apply switches my ($DIR,$CORPUS,$SCRIPTS_ROOT_DIR,$CONFIG,$HELP,$ERROR); -my $LM = "SRILM"; # SRILM is default. +my $LM = "KENLM"; # KENLM is default. my $BUILD_LM = "build-lm.sh"; +my $BUILD_KENLM = "$Bin/../../bin/lmplz"; my $NGRAM_COUNT = "ngram-count"; my $TRAIN_SCRIPT = "train-factored-phrase-model.perl"; my $MAX_LEN = 1; @@ -25,6 +26,7 @@ $ERROR = "training Aborted." 'dir=s' => \$DIR, 'ngram-count=s' => \$NGRAM_COUNT, 'build-lm=s' => \$BUILD_LM, + 'build-kenlm=s' => \$BUILD_KENLM, 'lm=s' => \$LM, 'train-script=s' => \$TRAIN_SCRIPT, 'scripts-root-dir=s' => \$SCRIPTS_ROOT_DIR, @@ -55,7 +57,7 @@ if ($HELP || $ERROR) { --max-len=int ... max phrase length (default: 1). = Language Model Training configuration = - --lm=[IRSTLM,SRILM] ... language model (default: SRILM). + --lm=[IRSTLM,SRILM,KENLM] ... language model (default: KENLM). --build-lm=file ... path to build-lm.sh if not in \$PATH (used only with --lm=IRSTLM). --ngram-count=file ... 
path to ngram-count.sh if not in \$PATH (used only with --lm=SRILM). @@ -110,10 +112,14 @@ sub train_lm { if (uc $LM eq "IRSTLM") { $cmd = "$BUILD_LM -t /tmp -i $CORPUS -n 3 -o $DIR/cased.irstlm.gz"; } - else { + elsif (uc $LM eq "SRILM") { $LM = "SRILM"; $cmd = "$NGRAM_COUNT -text $CORPUS -lm $DIR/cased.srilm.gz -interpolate -kndiscount"; } + else { + $LM = "KENLM"; + $cmd = "$BUILD_KENLM --prune 0 0 1 -S 50% -T $DIR/lmtmp --order 3 --text $CORPUS --arpa $DIR/cased.kenlm.gz"; + } print STDERR "** Using $LM **" . "\n"; print STDERR $cmd."\n"; system($cmd) == 0 || die("Language model training failed with error " . ($? >> 8) . "\n"); @@ -160,9 +166,12 @@ sub train_recase_model { if (uc $LM eq "IRSTLM") { $cmd .= " --lm 0:3:$DIR/cased.irstlm.gz:1"; } - else { + elsif (uc $LM eq "SRILM") { $cmd .= " --lm 0:3:$DIR/cased.srilm.gz:8"; } + else { + $cmd .= " --lm 0:3:$DIR/cased.kenlm.gz:8"; + } $cmd .= " -config $CONFIG" if $CONFIG; print STDERR $cmd."\n"; system($cmd) == 0 || die("Recaser model training failed with error " . ($? >> 8) . 
"\n"); diff --git a/scripts/recaser/truecase.perl b/scripts/recaser/truecase.perl index 74b55045b..0a4d366e0 100755 --- a/scripts/recaser/truecase.perl +++ b/scripts/recaser/truecase.perl @@ -35,7 +35,7 @@ while() { my ($WORD,$MARKUP) = split_xml($_); my $sentence_start = 1; for(my $i=0;$i<=$#$WORD;$i++) { - print " " if $i; + print " " if $i && $$MARKUP[$i] eq ''; print $$MARKUP[$i]; my ($word,$otherfactors); @@ -67,7 +67,7 @@ while() { if ( defined($SENTENCE_END{ $word })) { $sentence_start = 1; } elsif (!defined($DELAYED_SENTENCE_START{ $word })) { $sentence_start = 0; } } - print " ".$$MARKUP[$#$MARKUP]; + print $$MARKUP[$#$MARKUP]; print "\n"; } diff --git a/scripts/server/moses.py b/scripts/server/moses.py index 32c53fa2a..155458b9b 100644 --- a/scripts/server/moses.py +++ b/scripts/server/moses.py @@ -31,7 +31,7 @@ class ProcessWrapper: def start(self, stdin=PIPE, stdout=PIPE): if self.process: raise Exception("Process is already running") - self.process = Popen(cmd, stdin = stdin, stdout = stdout) + self.process = Popen(self.cmd, stdin = stdin, stdout = stdout) return def __del__(self): @@ -57,6 +57,7 @@ class SentenceSplitter(ProcessWrapper): def __init__(self,lang): ssplit_cmd = moses_root+"/scripts/ems/support/split-sentences.perl" self.cmd = [ssplit_cmd, "-b", "-q", "-l",lang] + self.process = None return def __call__(self,input): @@ -91,15 +92,17 @@ class Tokenizer(LineProcessor): def __init__(self,lang,args=["-a","-no-escape"]): tok_cmd = moses_root+"/scripts/tokenizer/tokenizer.perl" self.cmd = [tok_cmd,"-b", "-q", "-l", lang] + args + self.process = None return -class TrueCaser(LineProcessor): +class Truecaser(LineProcessor): """ Truecaser wrapper. 
""" def __init__(self,model): - trucase_cmd = moses_root+"/scripts/recaser/truecase.perl" + truecase_cmd = moses_root+"/scripts/recaser/truecase.perl" self.cmd = [truecase_cmd,"-b", "--model",model] + self.process = None return pass diff --git a/scripts/training/mert-moses.pl b/scripts/training/mert-moses.pl index d24d11ef9..d1ac5828a 100755 --- a/scripts/training/mert-moses.pl +++ b/scripts/training/mert-moses.pl @@ -863,7 +863,8 @@ while (1) { $mira_settings .= "$batch_mira_args "; } - $mira_settings .= " --dense-init run$run.dense"; + $mira_settings .= " --dense-init run$run.$weights_in_file"; + #$mira_settings .= " --dense-init run$run.dense"; if (-e "run$run.sparse-weights") { $mira_settings .= " --sparse-init run$run.sparse-weights"; } @@ -1237,6 +1238,7 @@ sub run_decoder { $decoder_cmd = "$___DECODER $___DECODER_FLAGS -config $___CONFIG -inputtype $___INPUTTYPE $decoder_config $lsamp_cmd $nbest_list_cmd -input-file $___DEV_F > run$run.out"; } + print STDERR "Executing: $decoder_cmd \n"; safesystem($decoder_cmd) or die "The decoder died. 
CONFIG WAS $decoder_config \n"; if (!$___HG_MIRA) { @@ -1308,6 +1310,7 @@ sub get_featlist_from_moses { } else { print STDERR "Asking moses for feature names and values from $___CONFIG\n"; my $cmd = "$___DECODER $___DECODER_FLAGS -config $configfn -inputtype $___INPUTTYPE -show-weights > $featlistfn"; + print STDERR "Executing: $cmd\n"; safesystem($cmd) or die "Failed to run moses with the config $configfn"; } return get_featlist_from_file($featlistfn); diff --git a/scripts/training/reduce-factors.perl b/scripts/training/reduce-factors.perl index fd4906a48..c7269abf9 100755 --- a/scripts/training/reduce-factors.perl +++ b/scripts/training/reduce-factors.perl @@ -47,7 +47,9 @@ sub reduce_factors { $firstline =~ s/^\s*//; $firstline =~ s/\s.*//; # count factors - my $maxfactorindex = $firstline =~ tr/|/|/; + my @WORD = split(/ /,$firstline); + my @FACTOR = split(/$___FACTOR_DELIMITER/,$WORD[0]); + my $maxfactorindex = scalar(@FACTOR)-1; if (join(",", @INCLUDE) eq join(",", 0..$maxfactorindex)) { # create just symlink; preserving compression my $realfull = $full; @@ -107,3 +109,24 @@ sub open_or_zcat { open($hdl,$read) or die "Can't read $fn ($read)"; return $hdl; } + +sub safesystem { + print STDERR "Executing: @_\n"; + system(@_); + if ($? == -1) { + print STDERR "ERROR: Failed to execute: @_\n $!\n"; + exit(1); + } + elsif ($? & 127) { + printf STDERR "ERROR: Execution of: @_\n died with signal %d, %s coredump\n", + ($? & 127), ($? & 128) ? 'with' : 'without'; + exit(1); + } + else { + my $exitcode = $? >> 8; + print STDERR "Exit code: $exitcode\n" if $exitcode; + return ! 
$exitcode; + } +} + + diff --git a/scripts/training/train-model.perl b/scripts/training/train-model.perl index 46a7e1fe6..a9ed58535 100755 --- a/scripts/training/train-model.perl +++ b/scripts/training/train-model.perl @@ -189,6 +189,7 @@ $_GIZA_F2E = File::Spec->rel2abs($_GIZA_F2E) if defined($_GIZA_F2E); my $_SCORE_OPTIONS; # allow multiple switches foreach (@_SCORE_OPTIONS) { $_SCORE_OPTIONS .= $_." "; } chop($_SCORE_OPTIONS) if $_SCORE_OPTIONS; + my $_EXTRACT_OPTIONS; # allow multiple switches foreach (@_EXTRACT_OPTIONS) { $_EXTRACT_OPTIONS .= $_." "; } chop($_EXTRACT_OPTIONS) if $_EXTRACT_OPTIONS; @@ -754,7 +755,9 @@ sub reduce_factors { $firstline =~ s/^\s*//; $firstline =~ s/\s.*//; # count factors - my $maxfactorindex = $firstline =~ tr/$___FACTOR_DELIMITER/$___FACTOR_DELIMITER/; + my @WORD = split(/ /,$firstline); + my @FACTOR = split(/$___FACTOR_DELIMITER/,$WORD[0]); + my $maxfactorindex = scalar(@FACTOR)-1; if (join(",", @INCLUDE) eq join(",", 0..$maxfactorindex)) { # create just symlink; preserving compression my $realfull = $full; @@ -1546,6 +1549,7 @@ sub score_phrase_phrase_extract { my $NEG_LOG_PROB = (defined($_SCORE_OPTIONS) && $_SCORE_OPTIONS =~ /NegLogProb/); my $NO_LEX = (defined($_SCORE_OPTIONS) && $_SCORE_OPTIONS =~ /NoLex/); my $MIN_COUNT_HIERARCHICAL = (defined($_SCORE_OPTIONS) && $_SCORE_OPTIONS =~ /MinCountHierarchical ([\d\.]+)/) ? $1 : undef; + my $SPAN_LENGTH = (defined($_SCORE_OPTIONS) && $_SCORE_OPTIONS =~ /SpanLength/); my $CORE_SCORE_OPTIONS = ""; $CORE_SCORE_OPTIONS .= " --LogProb" if $LOG_PROB; $CORE_SCORE_OPTIONS .= " --NegLogProb" if $NEG_LOG_PROB; @@ -1584,6 +1588,7 @@ sub score_phrase_phrase_extract { $cmd .= " --NoWordAlignment" if $_OMIT_WORD_ALIGNMENT; $cmd .= " --KneserNey" if $KNESER_NEY; $cmd .= " --GoodTuring" if $GOOD_TURING && $inverse eq ""; + $cmd .= " --SpanLength" if $SPAN_LENGTH && $inverse eq ""; $cmd .= " --UnalignedPenalty" if $UNALIGNED_COUNT; $cmd .= " --UnalignedFunctionWordPenalty ".($inverse ? 
$UNALIGNED_FW_F : $UNALIGNED_FW_E) if $UNALIGNED_FW_COUNT; $cmd .= " --MinCountHierarchical $MIN_COUNT_HIERARCHICAL" if $MIN_COUNT_HIERARCHICAL; diff --git a/scripts/training/wrappers/conll2mosesxml.py b/scripts/training/wrappers/conll2mosesxml.py new file mode 100755 index 000000000..d85695b16 --- /dev/null +++ b/scripts/training/wrappers/conll2mosesxml.py @@ -0,0 +1,188 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# Author: Rico Sennrich + +# takes a file in the CoNLL dependency format (from the CoNLL-X shared task on dependency parsing; http://ilk.uvt.nl/conll/#dataformat ) +# and produces Moses XML format. Note that the structure is built based on fields 9 and 10 (projective HEAD and RELATION), +# which not all parsers produce. + +# usage: conll2mosesxml.py [--brackets] < input_file > output_file + +from __future__ import print_function, unicode_literals +import sys +import re +import codecs +from collections import namedtuple,defaultdict +from lxml import etree as ET + + +Word = namedtuple('Word', ['pos','word','lemma','tag','head','func', 'proj_head', 'proj_func']) + +def main(output_format='xml'): + sentence = [] + + for line in sys.stdin: + + # process sentence + if line == "\n": + sentence.insert(0,[]) + if is_projective(sentence): + write(sentence,output_format) + else: + sys.stderr.write(' '.join(w.word for w in sentence[1:]) + '\n') + sys.stdout.write('\n') + sentence = [] + continue + + try: + pos, word, lemma, tag, tag2, morph, head, func, proj_head, proj_func = line.split() + except ValueError: # word may be unicode whitespace + pos, word, lemma, tag, tag2, morph, head, func, proj_head, proj_func = re.split(' *\t*',line.strip()) + + word = escape_special_chars(word) + lemma = escape_special_chars(lemma) + + if proj_head == '_': + proj_head = head + proj_func = func + + sentence.append(Word(int(pos), word, lemma, tag2,int(head), func, int(proj_head), proj_func)) + + +# this script performs the same escaping as escape-special-chars.perl in Moses. 
+# most of it is done in function write(), but quotation marks need to be processed first +def escape_special_chars(line): + + line = line.replace('\'',''') # xml + line = line.replace('"','"') # xml + + return line + + +# make a check if structure is projective +def is_projective(sentence): + dominates = defaultdict(set) + for i,w in enumerate(sentence): + dominates[i].add(i) + if not i: + continue + head = int(w.proj_head) + while head != 0: + if i in dominates[head]: + break + dominates[head].add(i) + head = int(sentence[head].proj_head) + + for i in dominates: + dependents = dominates[i] + if max(dependents) - min(dependents) != len(dependents)-1: + sys.stderr.write("error: non-projective structure.\n") + return False + return True + + +def write(sentence, output_format='xml'): + + if output_format == 'xml': + tree = create_subtree(0,sentence) + out = ET.tostring(tree, encoding = 'UTF-8').decode('UTF-8') + + if output_format == 'brackets': + out = create_brackets(0,sentence) + + out = out.replace('|','|') # factor separator + out = out.replace('[','[') # syntax non-terminal + out = out.replace(']',']') # syntax non-terminal + + out = out.replace('&apos;',''') # lxml is buggy if input is escaped + out = out.replace('&quot;','"') # lxml is buggy if input is escaped + + print(out) + +# write node in Moses XML format +def create_subtree(position, sentence): + + element = ET.Element('tree') + + if position: + element.set('label', sentence[position].proj_func) + else: + element.set('label', 'sent') + + for i in range(1,position): + if sentence[i].proj_head == position: + element.append(create_subtree(i, sentence)) + + if position: + + if preterminals: + head = ET.Element('tree') + head.set('label', sentence[position].tag) + head.text = sentence[position].word + element.append(head) + + else: + if len(element): + element[-1].tail = sentence[position].word + else: + element.text = sentence[position].word + + for i in range(position, len(sentence)): + if i and 
sentence[i].proj_head == position: + element.append(create_subtree(i, sentence)) + + return element + + +# write node in bracket format (Penn treebank style) +def create_brackets(position, sentence): + + if position: + element = "( " + sentence[position].proj_func + ' ' + else: + element = "( sent " + + for i in range(1,position): + if sentence[i].proj_head == position: + element += create_brackets(i, sentence) + + if position: + word = sentence[position].word + if word == ')': + word = 'RBR' + elif word == '(': + word = 'LBR' + + tag = sentence[position].tag + if tag == '$(': + tag = '$BR' + + if preterminals: + element += '( ' + tag + ' ' + word + ' ) ' + else: + element += word + ' ) ' + + for i in range(position, len(sentence)): + if i and sentence[i].proj_head == position: + element += create_brackets(i, sentence) + + if preterminals or not position: + element += ') ' + + return element + +if __name__ == '__main__': + if sys.version_info < (3,0,0): + sys.stdin = codecs.getreader('UTF-8')(sys.stdin) + sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) + sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) + + if '--no_preterminals' in sys.argv: + preterminals = False + else: + preterminals = True + + if '--brackets' in sys.argv: + main('brackets') + else: + main('xml') diff --git a/scripts/training/wrappers/make-factor-brown-cluster-mkcls.perl b/scripts/training/wrappers/make-factor-brown-cluster-mkcls.perl index 60f341de8..13aa7f912 100755 --- a/scripts/training/wrappers/make-factor-brown-cluster-mkcls.perl +++ b/scripts/training/wrappers/make-factor-brown-cluster-mkcls.perl @@ -2,7 +2,7 @@ use strict; -my ($cluster_file,$in,$out,$tmp) = @ARGV; +my ($lowercase, $cluster_file,$in,$out,$tmp) = @ARGV; my $CLUSTER = &read_cluster_from_mkcls($cluster_file); @@ -17,7 +17,10 @@ while() { s/ $//; my $first = 1; foreach my $word (split) { - my $cluster = defined($$CLUSTER{$word}) ? 
$$CLUSTER{$word} : ""; + if ($lowercase) { + $word = lc($word); + } + my $cluster = defined($$CLUSTER{$word}) ? $$CLUSTER{$word} : "0"; print OUT " " unless $first; print OUT $cluster; $first = 0; diff --git a/scripts/training/wrappers/parse-de-berkeley.perl b/scripts/training/wrappers/parse-de-berkeley.perl index 5d4a4d313..03d90eaca 100755 --- a/scripts/training/wrappers/parse-de-berkeley.perl +++ b/scripts/training/wrappers/parse-de-berkeley.perl @@ -1,21 +1,28 @@ -#!/usr/bin/perl -w +#!/usr/bin/perl -w use strict; use Getopt::Long "GetOptions"; use FindBin qw($RealBin); -my ($JAR,$GRAMMAR,$SPLIT_HYPHEN,$SPLIT_SLASH,$MARK_SPLIT,$BINARIZE); +my ($JAR,$GRAMMAR,$SPLIT_HYPHEN,$SPLIT_SLASH,$MARK_SPLIT,$BINARIZE,$UNPARSEABLE); -die("ERROR: syntax is: parse-de-berkeley.perl [-split-hyphen] [-split-slash] [-mark-split] [-binarize] -jar jar-file -gr grammar < in > out\n") +$UNPARSEABLE = 0; + +die("ERROR: syntax is: parse-de-berkeley.perl [-split-hyphen] [-split-slash] [-mark-split] [-binarize] -jar jar-file -gr grammar -unparseable < in > out\n") unless &GetOptions ('jar=s' => \$JAR, 'gr=s' => \$GRAMMAR, 'split-hyphen' => \$SPLIT_HYPHEN, 'split-slash' => \$SPLIT_SLASH, 'mark-split' => \$MARK_SPLIT, - 'binarize' => \$BINARIZE) + 'binarize' => \$BINARIZE, + 'unparseable' => \$UNPARSEABLE + + ) && defined($JAR) && defined($GRAMMAR); +#print STDERR "UNPARSEABLE=$UNPARSEABLE\n"; + die("ERROR: could not find jar file '$JAR'\n") unless -e $JAR; die("ERROR: could not find grammar file '$GRAMMAR'\n") unless -e $GRAMMAR; @@ -26,9 +33,15 @@ $SPLIT_SLASH = $SPLIT_SLASH ? 
"| $RealBin/syntax-hyphen-splitting.perl -slash $B $SPLIT_SLASH .= " -mark-split" if $SPLIT_SLASH && $MARK_SPLIT; my $tmp = "/tmp/parse-de-berkeley.$$"; +my $tmpEscaped = "/tmp/parse-de-berkeley.2.$$"; +#print STDERR "tmp=$tmp\n"; +#print STDERR "tmpEscaped=$tmpEscaped\n"; open(TMP,"| $RealBin/../../tokenizer/deescape-special-chars.perl > $tmp"); +open(TMPESCAPED, ">>$tmpEscaped"); while() { + print TMPESCAPED $_; + # unsplit hyphens s/ \@-\@ /-/g if $SPLIT_HYPHEN; # unsplit slashes @@ -44,14 +57,30 @@ while() { print TMP $_; } close(TMP); +close(TMPESCAPED); my $cmd = "cat $tmp | java -Xmx10000m -Xms10000m -Dfile.encoding=UTF8 -jar $JAR -gr $GRAMMAR -maxLength 1000 $BINARIZE | $RealBin/berkeleyparsed2mosesxml.perl $SPLIT_HYPHEN $SPLIT_SLASH"; -print STDERR $cmd."\n"; +#print STDERR "Executing: $cmd \n"; + +open (TMP, $tmp); +open (TMPESCAPED, $tmpEscaped); open(PARSE,"$cmd|"); while() { s/\\\@/\@/g; - print $_; + my $outLine = $_; + my $unparsedLine = ; + + #print STDERR "unparsedLine=$unparsedLine"; + #print STDERR "outLine=$outLine" .length($outLine) ."\n"; + + if ($UNPARSEABLE == 1 && length($outLine) == 1) { + print $unparsedLine; + } + else { + print $outLine; + } } close(PARSE); `rm $tmp`; +`rm $tmpEscaped`; diff --git a/scripts/training/wrappers/parse-de-bitpar.perl b/scripts/training/wrappers/parse-de-bitpar.perl index 370187d32..f884b5c01 100755 --- a/scripts/training/wrappers/parse-de-bitpar.perl +++ b/scripts/training/wrappers/parse-de-bitpar.perl @@ -15,6 +15,7 @@ my $DEESCAPE = "$SCRIPTS_ROOT_DIR/tokenizer/deescape-special-chars.perl"; my $DEBUG = 0; my $BASIC = 0; my $OLD_BITPAR = 0; +my $UNPARSEABLE = 0; my $RAW = ""; @@ -22,7 +23,8 @@ GetOptions( "basic" => \$BASIC, "bitpar=s" => \$BITPAR, "old-bitpar" => \$OLD_BITPAR, - "raw=s" => \$RAW + "raw=s" => \$RAW, + "unparseable" => \$UNPARSEABLE ) or die("ERROR: unknown options"); `mkdir -p $TMPDIR`; @@ -71,6 +73,12 @@ if ($OLD_BITPAR) open(PARSER,$pipeline); while(my $line = ) { if ($line =~ /^No parse 
for/) { + if ($UNPARSEABLE) { + my $len = length($line); + $line = substr($line, 15, $len - 17); + $line = escape($line); + print $line; + } print "\n"; next; } diff --git a/scripts/training/wrappers/tagger-german-chunk.perl b/scripts/training/wrappers/tagger-german-chunk.perl new file mode 100755 index 000000000..1e4b5495d --- /dev/null +++ b/scripts/training/wrappers/tagger-german-chunk.perl @@ -0,0 +1,144 @@ +#!/usr/bin/perl + +use strict; +use Getopt::Long "GetOptions"; + +# split -a 5 -d ../europarl.clean.5.de +# ls -1 x????? | ~/workspace/coreutils/parallel/src/parallel /home/s0565741/workspace/treetagger/cmd/run-tagger-chunker-german.sh +# cat x?????.out > ../out + +my $chunkedPath; +my $treetaggerPath; + +GetOptions('chunked=s' => \$chunkedPath, + 'tree-tagger=s' => \$treetaggerPath); + +binmode(STDIN, ":utf8"); +binmode(STDOUT, ":utf8"); + +#my $TMPDIR= "/tmp/chunker.$$"; +my $TMPDIR= "chunker.$$"; +print STDERR "TMPDIR=$TMPDIR\n"; +print STDERR "chunkedPath=$chunkedPath\n"; +`mkdir $TMPDIR`; + +my $inPath = "$TMPDIR/in"; + +open(IN, ">$inPath"); +binmode(IN, ":utf8"); + +while(my $line = ) { + chomp($line); + print IN "$line\n"; +} +close(IN); + +# call chunker +if (!defined($chunkedPath)) { + if (!defined($treetaggerPath)) { + print STDERR "must defined -tree-tagger \n"; + exit(1); + } + + $chunkedPath = "$TMPDIR/chunked"; + print STDERR "chunkedPath not defined. 
Now $chunkedPath \n"; + my $cmd = "$treetaggerPath/cmd/tagger-chunker-german-utf8 < $inPath > $chunkedPath"; + `$cmd`; +} + +# convert chunked file into Moses XML +open(CHUNKED, "$chunkedPath"); +open(IN, "$inPath"); +binmode(CHUNKED, ":utf8"); +binmode(IN, ":utf8"); + +my $sentence = ; +chomp($sentence); +my @words = split(/ /, $sentence); +my $numWords = scalar @words; +my $prevTag = ""; +my $wordPos = -1; + +while(my $chunkLine = ) { + chomp($chunkLine); + my @chunkToks = split(/\t/, $chunkLine); + + if (substr($chunkLine, 0, 1) eq "<") { + if (substr($chunkLine, 0, 2) eq " "; + $prevTag = ""; + + if ($wordPos == ($numWords - 1)) { + # closing bracket of last word in sentence + print "\n"; + $sentence = ; + chomp($sentence); + @words = split(/ /, $sentence); + $numWords = scalar @words; + $wordPos = -1; + } + } + else { + # beginning of tag + if ($wordPos == ($numWords - 1)) { + # closing bracket of last word in sentence + print "\n"; + $sentence = ; + chomp($sentence); + @words = split(/ /, $sentence); + $numWords = scalar @words; + $wordPos = -1; + } + + $prevTag = $chunkToks[0]; + $prevTag = substr($prevTag, 1, length($prevTag) - 2); + print ""; + } + } + else { + # word + ++$wordPos; + + if (scalar(@chunkToks) != 3) { + # parse error + print STDERR "CHUNK LINES SHOULD BE 3 TOKS\n"; + exit(1); + } + + if ($wordPos >= $numWords) { + # on new sentence now + if (length($prevTag) > 0) { + print ""; + } + print "\n"; + if (length($prevTag) > 0) { + print ""; + } + + $sentence = ; + chomp($sentence); + @words = split(/ /, $sentence); + $numWords = scalar @words; + $wordPos = 0; + } + + if ($chunkToks[0] ne $words[$wordPos]) { + # word in chunk input and sentence should match + print STDERR "NOT EQUAL:" .$chunkToks[0] ." != " .$words[$wordPos] ."\n"; + exit(1); + } + + print $chunkToks[0] . " "; + + } + +} + +print "\n"; + +close(IN); +close(CHUNKED); + +`rm -rf $TMPDIR`; +
DateTimeTestnameRevisionBranchTimePrevtimePrevrevChange (%)Time (Basebranch)Change (%, Basebranch)Time (Days -2)Change (%, Days -2)Time (Days -3)Change (%, Days -3)Time (Days -4)Change (%, Days -4)Time (Days -5)Change (%, Days -5)Time (Days -6)Change (%, Days -6)Time (Days -7)Change (%, Days -7)Time (Days -14)Change (%, Days -14)Time (Years -1)Change (%, Years -1)
' + str(resline.previous) + '' + str(resline.percentage) + '' + str(resline.percentage) + '' + str(resline.percentage) + '
' + logLine2.date + '' + logLine2.time + '' +\ + res1.testname + '' + res1.revision[:10] + '' + res1.branch + '' +\ + str(res1.current) + '' + str(res1.previous) + '' + res1.prevrev[:10] + '' + str(res1.percentage) + '' + str(res1.percentage) + '' + str(res1.percentage) + '' + str(res2.previous) + '' + str(res2.percentage) + '' + str(res2.percentage) + '' + str(res2.percentage) + 'N/AN/A