mirror of
https://github.com/moses-smt/mosesdecoder.git
synced 2024-12-27 22:14:57 +03:00
gets right score
This commit is contained in:
parent
b590387ed9
commit
0d0f75b6e0
@ -22,7 +22,7 @@ namespace Moses2
|
||||
{
|
||||
|
||||
struct KenLMState : public FFState {
|
||||
lm::ngram::State state;
|
||||
lm::ngram::ChartState state;
|
||||
virtual size_t hash() const {
|
||||
size_t ret = hash_value(state);
|
||||
return ret;
|
||||
@ -35,11 +35,14 @@ struct KenLMState : public FFState {
|
||||
|
||||
virtual std::string ToString() const
|
||||
{
|
||||
/*
|
||||
stringstream ss;
|
||||
for (size_t i = 0; i < state.Length(); ++i) {
|
||||
ss << state.words[i] << " ";
|
||||
}
|
||||
return ss.str();
|
||||
*/
|
||||
return "KenLMState";
|
||||
}
|
||||
|
||||
};
|
||||
@ -132,7 +135,9 @@ void KENLM::EmptyHypothesisState(FFState &state,
|
||||
const Hypothesis &hypo) const
|
||||
{
|
||||
KenLMState &stateCast = static_cast<KenLMState&>(state);
|
||||
stateCast.state = m_ngram->BeginSentenceState();
|
||||
lm::ngram::RuleScore<Model> scorer(*m_ngram, stateCast.state);
|
||||
scorer.BeginSentence();
|
||||
scorer.Finish();
|
||||
}
|
||||
|
||||
void
|
||||
@ -179,24 +184,25 @@ void KENLM::EvaluateWhenApplied(const Manager &mgr,
|
||||
Scores &scores,
|
||||
FFState &state) const
|
||||
{
|
||||
const System &system = mgr.system;
|
||||
|
||||
const KenLMState &prevStateCast = static_cast<const KenLMState&>(prevState);
|
||||
KenLMState &stateCast = static_cast<KenLMState&>(state);
|
||||
|
||||
const System &system = mgr.system;
|
||||
|
||||
const lm::ngram::State &in_state = static_cast<const KenLMState&>(prevState).state;
|
||||
const lm::ngram::ChartState &prevKenState = prevStateCast.state;
|
||||
lm::ngram::ChartState &kenState = stateCast.state;
|
||||
|
||||
const TargetPhrase &tp = hypo.GetTargetPhrase();
|
||||
size_t tpSize = tp.GetSize();
|
||||
if (!tpSize) {
|
||||
stateCast.state = in_state;
|
||||
stateCast.state = prevKenState;
|
||||
return;
|
||||
}
|
||||
|
||||
// NEW CODE - start
|
||||
//const lm::ngram::ChartState &chartStateInIsolation = *static_cast<const lm::ngram::ChartState*>(tp.chartState);
|
||||
lm::ngram::ChartState newState;
|
||||
lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState);
|
||||
lm::ngram::RuleScore<Model> ruleScore(*m_ngram, kenState);
|
||||
ruleScore.NonTerminal(prevKenState, 0);
|
||||
|
||||
// each word in new tp
|
||||
for (size_t i = 0; i < tpSize; ++i) {
|
||||
@ -204,47 +210,21 @@ void KENLM::EvaluateWhenApplied(const Manager &mgr,
|
||||
lm::WordIndex lmInd = TranslateID(word);
|
||||
ruleScore.Terminal(lmInd);
|
||||
}
|
||||
float score = ruleScore.Finish();
|
||||
stateCast.state = newState.right;
|
||||
Model::State *state0 = &stateCast.state;
|
||||
// NEW CODE - end
|
||||
|
||||
const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
|
||||
//[begin, end) in STL-like fashion.
|
||||
const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
|
||||
const std::size_t adjust_end = std::min(end, begin + m_ngram->Order() - 1);
|
||||
|
||||
/*
|
||||
* OLD CODE
|
||||
std::size_t position = begin;
|
||||
|
||||
typename Model::State aux_state;
|
||||
typename Model::State *state0 = &stateCast.state, *state1 = &aux_state;
|
||||
|
||||
float score = m_ngram->Score(in_state, TranslateID(hypo.GetWord(position)), *state0);
|
||||
++position;
|
||||
for (; position < adjust_end; ++position) {
|
||||
score += m_ngram->Score(*state0, TranslateID(hypo.GetWord(position)), *state1);
|
||||
std::swap(state0, state1);
|
||||
}
|
||||
*/
|
||||
|
||||
if (hypo.GetBitmap().IsComplete()) {
|
||||
// Score end of sentence.
|
||||
std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
|
||||
const lm::WordIndex *last = LastIDs(hypo, &indices.front());
|
||||
score += m_ngram->FullScoreForgotState(&indices.front(), last, m_ngram->GetVocabulary().EndSentence(), stateCast.state).prob;
|
||||
} else if (adjust_end < end) {
|
||||
// Get state after adding a long phrase.
|
||||
std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
|
||||
const lm::WordIndex *last = LastIDs(hypo, &indices.front());
|
||||
m_ngram->GetState(&indices.front(), last, stateCast.state);
|
||||
} else if (state0 != &stateCast.state) {
|
||||
// Short enough phrase that we can just reuse the state.
|
||||
stateCast.state = *state0;
|
||||
ruleScore.Terminal(m_ngram->GetVocabulary().EndSentence());
|
||||
}
|
||||
// NEW CODE - end
|
||||
|
||||
score = TransformLMScore(score);
|
||||
|
||||
float score10 = ruleScore.Finish();
|
||||
float score = TransformLMScore(score10);
|
||||
|
||||
/*
|
||||
stringstream strme;
|
||||
hypo.OutputToStream(strme);
|
||||
cerr << "HELLO " << score10 << " " << score << " " << strme.str() << " " << hypo.GetBitmap() << endl;
|
||||
*/
|
||||
|
||||
bool OOVFeatureEnabled = false;
|
||||
if (OOVFeatureEnabled) {
|
||||
|
@ -173,7 +173,7 @@ void Manager::OutputBest() const
|
||||
}
|
||||
|
||||
bestHypo->OutputToStream(out);
|
||||
//cerr << "BEST TRANSLATION: " << *bestHypo;
|
||||
cerr << "BEST TRANSLATION: " << *bestHypo;
|
||||
}
|
||||
else {
|
||||
if (system.outputHypoScore) {
|
||||
@ -184,7 +184,7 @@ void Manager::OutputBest() const
|
||||
out << "\n";
|
||||
|
||||
system.bestCollector.Write(m_input->GetTranslationId(), out.str());
|
||||
//cerr << endl;
|
||||
cerr << endl;
|
||||
|
||||
|
||||
}
|
||||
|
@ -1,60 +1,27 @@
|
||||
#include <string>
|
||||
#include <boost/program_options.hpp>
|
||||
#include "util/usage.hh"
|
||||
#include "moses/TranslationModel/ProbingPT/storing.hh"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
string inPath, outPath;
|
||||
int num_scores = 4;
|
||||
int num_lex_scores = 0;
|
||||
bool log_prob = false;
|
||||
int max_cache_size = 50000;
|
||||
|
||||
namespace po = boost::program_options;
|
||||
po::options_description desc("Options");
|
||||
desc.add_options()
|
||||
("help", "Print help messages")
|
||||
("input-pt", po::value<string>()->required(), "Text pt")
|
||||
("output-dir", po::value<string>()->required(), "Directory when binary files will be written")
|
||||
("num-scores", po::value<int>()->default_value(num_scores), "Number of pt scores")
|
||||
("num-lex-scores", po::value<int>()->default_value(num_lex_scores), "Number of lexicalized reordering scores")
|
||||
("log-prob", "log (and floor) probabilities before storing")
|
||||
("max-cache-size", po::value<int>()->default_value(max_cache_size), "Maximum number of high-count source lines to write to cache file. 0=no cache, negative=no limit")
|
||||
const char * is_reordering = "false";
|
||||
|
||||
;
|
||||
|
||||
po::variables_map vm;
|
||||
try {
|
||||
po::store(po::parse_command_line(argc, argv, desc),
|
||||
vm); // can throw
|
||||
|
||||
/** --help option
|
||||
*/
|
||||
if ( vm.count("help")) {
|
||||
std::cout << desc << std::endl;
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
po::notify(vm); // throws on error, so do after help in case
|
||||
// there are any problems
|
||||
} catch(po::error& e) {
|
||||
std::cerr << "ERROR: " << e.what() << std::endl << std::endl;
|
||||
std::cerr << desc << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
if (!(argc == 5 || argc == 4)) {
|
||||
// Tell the user how to run the program
|
||||
std::cerr << "Provided " << argc << " arguments, needed 4 or 5." << std::endl;
|
||||
std::cerr << "Usage: " << argv[0] << " path_to_phrasetable output_dir num_scores is_reordering" << std::endl;
|
||||
std::cerr << "is_reordering should be either true or false, but it is currently a stub feature." << std::endl;
|
||||
//std::cerr << "Usage: " << argv[0] << " path_to_phrasetable number_of_uniq_lines output_bin_file output_hash_table output_vocab_id" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (vm.count("input-pt")) inPath = vm["input-pt"].as<string>();
|
||||
if (vm.count("output-dir")) outPath = vm["output-dir"].as<string>();
|
||||
if (vm.count("num-scores")) num_scores = vm["num-scores"].as<int>();
|
||||
if (vm.count("num-lex-scores")) num_lex_scores = vm["num-lex-scores"].as<int>();
|
||||
if (vm.count("max-cache-size")) max_cache_size = vm["max-cache-size"].as<int>();
|
||||
if (vm.count("log-prob")) log_prob = true;
|
||||
if (argc == 5) {
|
||||
is_reordering = argv[4];
|
||||
}
|
||||
|
||||
|
||||
createProbingPT(inPath.c_str(), outPath.c_str(), num_scores, num_lex_scores, log_prob, max_cache_size);
|
||||
createProbingPT(argv[1], argv[2], argv[3], is_reordering);
|
||||
|
||||
util::PrintUsage(std::cout);
|
||||
return 0;
|
||||
|
@ -1,7 +1,4 @@
|
||||
#include "huffmanish.hh"
|
||||
#include "util/string_piece.hh"
|
||||
|
||||
using namespace std;
|
||||
|
||||
Huffman::Huffman (const char * filepath)
|
||||
{
|
||||
@ -141,19 +138,12 @@ void Huffman::serialize_maps(const char * dirname)
|
||||
os2.close();
|
||||
}
|
||||
|
||||
std::vector<unsigned char> Huffman::full_encode_line(line_text &line, bool log_prob)
|
||||
std::vector<unsigned char> Huffman::full_encode_line(line_text line)
|
||||
{
|
||||
return vbyte_encode_line((encode_line(line, log_prob)));
|
||||
return vbyte_encode_line((encode_line(line)));
|
||||
}
|
||||
|
||||
//! make sure score doesn't fall below LOWEST_SCORE
|
||||
inline float FloorScore(float logScore)
|
||||
{
|
||||
const float LOWEST_SCORE = -100.0f;
|
||||
return (std::max)(logScore , LOWEST_SCORE);
|
||||
}
|
||||
|
||||
std::vector<unsigned int> Huffman::encode_line(line_text &line, bool log_prob)
|
||||
std::vector<unsigned int> Huffman::encode_line(line_text line)
|
||||
{
|
||||
std::vector<unsigned int> retvector;
|
||||
|
||||
@ -172,18 +162,9 @@ std::vector<unsigned int> Huffman::encode_line(line_text &line, bool log_prob)
|
||||
//Sometimes we have too big floats to handle, so first convert to double
|
||||
double tempnum = atof(probit->data());
|
||||
float num = (float)tempnum;
|
||||
if (log_prob) {
|
||||
num = FloorScore(log(num));
|
||||
if (num == 0.0f) num = 0.0000000001;
|
||||
}
|
||||
//cerr << "num=" << num << endl;
|
||||
retvector.push_back(reinterpret_float(&num));
|
||||
probit++;
|
||||
}
|
||||
|
||||
// append LexRO prob to pt scores
|
||||
AppendLexRO(line, retvector, log_prob);
|
||||
|
||||
//Add a zero;
|
||||
retvector.push_back(0);
|
||||
|
||||
@ -192,72 +173,9 @@ std::vector<unsigned int> Huffman::encode_line(line_text &line, bool log_prob)
|
||||
retvector.push_back(word_all1_huffman.find(splitWordAll1(line.word_align))->second);
|
||||
retvector.push_back(0);
|
||||
|
||||
//The rest of the components might not be there, but add them (as reinterpretation to byte arr)
|
||||
//In the future we should really make those optional to save space
|
||||
|
||||
//Counts
|
||||
const char* counts = line.counts.data();
|
||||
size_t counts_size = line.counts.size();
|
||||
for (size_t i = 0; i < counts_size; i++) {
|
||||
retvector.push_back(counts[i]);
|
||||
}
|
||||
retvector.push_back(0);
|
||||
|
||||
//Sparse score
|
||||
const char* sparse_score = line.sparse_score.data();
|
||||
size_t sparse_score_size = line.sparse_score.size();
|
||||
for (size_t i = 0; i < sparse_score_size; i++) {
|
||||
retvector.push_back(sparse_score[i]);
|
||||
}
|
||||
retvector.push_back(0);
|
||||
|
||||
//Property
|
||||
const char* property = line.property_to_be_binarized.data();
|
||||
size_t property_size = line.property_to_be_binarized.size();
|
||||
for (size_t i = 0; i < property_size; i++) {
|
||||
retvector.push_back(property[i]);
|
||||
}
|
||||
retvector.push_back(0);
|
||||
|
||||
return retvector;
|
||||
}
|
||||
|
||||
void Huffman::AppendLexRO(line_text &line, std::vector<unsigned int> &retvector, bool log_prob)
|
||||
{
|
||||
const StringPiece &origProperty = line.property_orig;
|
||||
StringPiece::size_type startPos = origProperty.find("{{LexRO ");
|
||||
|
||||
if (startPos != StringPiece::npos) {
|
||||
StringPiece::size_type endPos = origProperty.find("}}", startPos + 8);
|
||||
StringPiece lexProb = origProperty.substr(startPos + 8, endPos - startPos - 8);
|
||||
//cerr << "lexProb=" << lexProb << endl;
|
||||
|
||||
// append lex probs to pt probs
|
||||
util::TokenIter<util::SingleCharacter> it(lexProb, util::SingleCharacter(' '));
|
||||
while (it) {
|
||||
StringPiece probStr = *it;
|
||||
//cerr << "\t" << probStr << endl;
|
||||
|
||||
double tempnum = atof(probStr.data());
|
||||
float num = (float)tempnum;
|
||||
if (log_prob) {
|
||||
num = FloorScore(log(num));
|
||||
if (num == 0.0f) num = 0.0000000001;
|
||||
}
|
||||
|
||||
retvector.push_back(reinterpret_float(&num));
|
||||
|
||||
// exclude LexRO property from property column
|
||||
line.property_to_be_binarized = origProperty.substr(0, startPos).as_string()
|
||||
+ origProperty.substr(endPos + 2, origProperty.size() - endPos - 2).as_string();
|
||||
//cerr << "line.property_to_be_binarized=" << line.property_to_be_binarized << "AAAA" << endl;
|
||||
it++;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Huffman::produce_lookups()
|
||||
{
|
||||
//basically invert every map that we have
|
||||
@ -308,7 +226,7 @@ std::vector<target_text> HuffmanDecoder::full_decode_line (std::vector<unsigned
|
||||
std::vector<unsigned int>::iterator it = decoded_lines.begin(); //Iterator for them
|
||||
std::vector<unsigned int> current_target_phrase; //Current target phrase decoded
|
||||
|
||||
short zero_count = 0; //Count how many zeroes we have met. so far. Every 6 zeroes mean a new target phrase.
|
||||
short zero_count = 0; //Count home many zeroes we have met. so far. Every 3 zeroes mean a new target phrase.
|
||||
while(it != decoded_lines.end()) {
|
||||
if (zero_count == 1) {
|
||||
//We are extracting scores. we know how many scores there are so we can push them
|
||||
@ -320,7 +238,7 @@ std::vector<target_text> HuffmanDecoder::full_decode_line (std::vector<unsigned
|
||||
}
|
||||
}
|
||||
|
||||
if (zero_count == 6) {
|
||||
if (zero_count == 3) {
|
||||
//We have finished with this entry, decode it, and add it to the retvector.
|
||||
retvector.push_back(decode_line(current_target_phrase, num_scores));
|
||||
current_target_phrase.clear(); //Clear the current target phrase and the zero_count
|
||||
@ -334,7 +252,7 @@ std::vector<target_text> HuffmanDecoder::full_decode_line (std::vector<unsigned
|
||||
it++; //Go to the next word/symbol
|
||||
}
|
||||
//Don't forget the last remaining line!
|
||||
if (zero_count == 6) {
|
||||
if (zero_count == 3) {
|
||||
//We have finished with this entry, decode it, and add it to the retvector.
|
||||
retvector.push_back(decode_line(current_target_phrase, num_scores));
|
||||
current_target_phrase.clear(); //Clear the current target phrase and the zero_count
|
||||
@ -357,7 +275,7 @@ target_text HuffmanDecoder::decode_line (std::vector<unsigned int> input, int nu
|
||||
//Split the line into the proper arrays
|
||||
short num_zeroes = 0;
|
||||
int counter = 0;
|
||||
while (num_zeroes < 6) {
|
||||
while (num_zeroes < 3) {
|
||||
unsigned int num = input[counter];
|
||||
if (num == 0) {
|
||||
num_zeroes++;
|
||||
@ -373,12 +291,6 @@ target_text HuffmanDecoder::decode_line (std::vector<unsigned int> input, int nu
|
||||
continue;
|
||||
} else if (num_zeroes == 2) {
|
||||
wAll = num;
|
||||
} else if (num_zeroes == 3) {
|
||||
ret.counts.push_back(static_cast<char>(input[counter]));
|
||||
} else if (num_zeroes == 4) {
|
||||
ret.sparse_score.push_back(static_cast<char>(input[counter]));
|
||||
} else if (num_zeroes == 5) {
|
||||
ret.property.push_back(static_cast<char>(input[counter]));
|
||||
}
|
||||
counter++;
|
||||
}
|
||||
|
@ -53,10 +53,10 @@ public:
|
||||
void serialize_maps(const char * dirname);
|
||||
void produce_lookups();
|
||||
|
||||
std::vector<unsigned int> encode_line(line_text &line, bool log_prob);
|
||||
std::vector<unsigned int> encode_line(line_text line);
|
||||
|
||||
//encode line + variable byte ontop
|
||||
std::vector<unsigned char> full_encode_line(line_text &line, bool log_prob);
|
||||
std::vector<unsigned char> full_encode_line(line_text line);
|
||||
|
||||
//Getters
|
||||
const std::map<unsigned int, std::string> get_target_lookup_map() const {
|
||||
@ -69,9 +69,6 @@ public:
|
||||
unsigned long getUniqLines() {
|
||||
return uniq_lines;
|
||||
}
|
||||
|
||||
void AppendLexRO(line_text &line, std::vector<unsigned int> &retvector, bool log_prob);
|
||||
|
||||
};
|
||||
|
||||
class HuffmanDecoder
|
||||
|
@ -2,48 +2,41 @@
|
||||
|
||||
line_text splitLine(StringPiece textin)
|
||||
{
|
||||
const char delim[] = "|||";
|
||||
const char delim[] = " ||| ";
|
||||
line_text output;
|
||||
|
||||
//Tokenize
|
||||
util::TokenIter<util::MultiCharacter> it(textin, util::MultiCharacter(delim));
|
||||
//Get source phrase
|
||||
output.source_phrase = Trim(*it);
|
||||
//std::cerr << "output.source_phrase=" << output.source_phrase << "AAAA" << std::endl;
|
||||
output.source_phrase = *it;
|
||||
|
||||
//Get target_phrase
|
||||
it++;
|
||||
output.target_phrase = Trim(*it);
|
||||
//std::cerr << "output.target_phrase=" << output.target_phrase << "AAAA" << std::endl;
|
||||
output.target_phrase = *it;
|
||||
|
||||
//Get probabilities
|
||||
it++;
|
||||
output.prob = Trim(*it);
|
||||
//std::cerr << "output.prob=" << output.prob << "AAAA" << std::endl;
|
||||
output.prob = *it;
|
||||
|
||||
//Get WordAllignment
|
||||
it++;
|
||||
if (it == util::TokenIter<util::MultiCharacter>::end()) return output;
|
||||
output.word_align = Trim(*it);
|
||||
//std::cerr << "output.word_align=" << output.word_align << "AAAA" << std::endl;
|
||||
output.word_align = *it;
|
||||
|
||||
//Get count
|
||||
it++;
|
||||
if (it == util::TokenIter<util::MultiCharacter>::end()) return output;
|
||||
output.counts = Trim(*it);
|
||||
//std::cerr << "output.counts=" << output.counts << "AAAA" << std::endl;
|
||||
output.counts = *it;
|
||||
|
||||
//Get sparse_score
|
||||
it++;
|
||||
if (it == util::TokenIter<util::MultiCharacter>::end()) return output;
|
||||
output.sparse_score = Trim(*it);
|
||||
//std::cerr << "output.sparse_score=" << output.sparse_score << "AAAA" << std::endl;
|
||||
output.sparse_score = *it;
|
||||
|
||||
//Get property
|
||||
it++;
|
||||
if (it == util::TokenIter<util::MultiCharacter>::end()) return output;
|
||||
output.property_orig = Trim(*it);
|
||||
//std::cerr << "output.property=" << output.property << "AAAA" << std::endl;
|
||||
output.property = *it;
|
||||
|
||||
return output;
|
||||
}
|
||||
@ -54,11 +47,6 @@ std::vector<unsigned char> splitWordAll1(StringPiece textin)
|
||||
const char delim2[] = "-";
|
||||
std::vector<unsigned char> output;
|
||||
|
||||
//Case with no word alignments.
|
||||
if (textin.size() == 0) {
|
||||
return output;
|
||||
}
|
||||
|
||||
//Split on space
|
||||
util::TokenIter<util::MultiCharacter> it(textin, util::MultiCharacter(delim));
|
||||
|
||||
|
@ -17,8 +17,7 @@ struct line_text {
|
||||
StringPiece word_align;
|
||||
StringPiece counts;
|
||||
StringPiece sparse_score;
|
||||
StringPiece property_orig;
|
||||
std::string property_to_be_binarized;
|
||||
StringPiece property;
|
||||
};
|
||||
|
||||
//Struct for holding processed line
|
||||
@ -26,9 +25,6 @@ struct target_text {
|
||||
std::vector<unsigned int> target_phrase;
|
||||
std::vector<float> prob;
|
||||
std::vector<unsigned char> word_all1;
|
||||
std::vector<char> counts;
|
||||
std::vector<char> sparse_score;
|
||||
std::vector<char> property;
|
||||
};
|
||||
|
||||
//Ask if it's better to have it receive a pointer to a line_text struct
|
||||
|
@ -25,9 +25,9 @@ char * readTable(const char * filename, size_t size)
|
||||
}
|
||||
|
||||
|
||||
void serialize_table(char *mem, size_t size, const std::string &filename)
|
||||
void serialize_table(char *mem, size_t size, const char * filename)
|
||||
{
|
||||
std::ofstream os (filename.c_str(), std::ios::binary);
|
||||
std::ofstream os (filename, std::ios::binary);
|
||||
os.write((const char*)&mem[0], size);
|
||||
os.close();
|
||||
|
||||
|
@ -7,8 +7,6 @@
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
|
||||
#define API_VERSION 7
|
||||
|
||||
|
||||
//Hash table entry
|
||||
struct Entry {
|
||||
@ -34,6 +32,6 @@ struct Entry {
|
||||
//Define table
|
||||
typedef util::ProbingHashTable<Entry, boost::hash<uint64_t> > Table;
|
||||
|
||||
void serialize_table(char *mem, size_t size, const std::string &filename);
|
||||
void serialize_table(char *mem, size_t size, const char * filename);
|
||||
|
||||
char * readTable(const char * filename, size_t size);
|
||||
|
@ -43,9 +43,8 @@ QueryEngine::QueryEngine(const char * filepath) : decoder(filepath)
|
||||
std::ifstream config ((basepath + "/config").c_str());
|
||||
//Check API version:
|
||||
getline(config, line);
|
||||
int version = atoi(line.c_str());
|
||||
if (version != API_VERSION) {
|
||||
std::cerr << "The ProbingPT API has changed. " << version << "!=" << API_VERSION << " Please rebinarize your phrase tables." << std::endl;
|
||||
if (atoi(line.c_str()) != API_VERSION) {
|
||||
std::cerr << "The ProbingPT API has changed, please rebinarize your phrase tables." << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
//Get tablesize.
|
||||
@ -194,25 +193,6 @@ void QueryEngine::printTargetInfo(std::vector<target_text> target_phrases)
|
||||
std::cout << (short)target_phrases[i].word_all1[j] << " ";
|
||||
}
|
||||
}
|
||||
|
||||
//Print counts
|
||||
for (size_t j = 0; j < target_phrases[i].counts.size(); j++) {
|
||||
std::cout << target_phrases[i].counts[j];
|
||||
}
|
||||
std::cout << "\t";
|
||||
|
||||
//Print sparse_score
|
||||
for (size_t j = 0; j < target_phrases[i].sparse_score.size(); j++) {
|
||||
std::cout << target_phrases[i].sparse_score[j];
|
||||
}
|
||||
std::cout << "\t";
|
||||
|
||||
//Print properties
|
||||
for (size_t j = 0; j < target_phrases[i].property.size(); j++) {
|
||||
std::cout << target_phrases[i].property[j];
|
||||
}
|
||||
std::cout << "\t";
|
||||
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <sys/stat.h> //For finding size of file
|
||||
#include "vocabid.hh"
|
||||
#include <algorithm> //toLower
|
||||
#define API_VERSION 3
|
||||
|
||||
|
||||
char * read_binary_file(char * filename);
|
||||
|
@ -1,5 +1,4 @@
|
||||
#include "storing.hh"
|
||||
#include "moses/Util.h"
|
||||
|
||||
BinaryFileWriter::BinaryFileWriter (std::string basepath) : os ((basepath + "/binfile.dat").c_str(), std::ios::binary)
|
||||
{
|
||||
@ -40,7 +39,7 @@ BinaryFileWriter::~BinaryFileWriter ()
|
||||
}
|
||||
|
||||
void createProbingPT(const char * phrasetable_path, const char * target_path,
|
||||
int num_scores, int num_lex_scores, bool log_prob, int max_cache_size)
|
||||
const char * num_scores, const char * is_reordering)
|
||||
{
|
||||
//Get basepath and create directory if missing
|
||||
std::string basepath(target_path);
|
||||
@ -69,9 +68,6 @@ void createProbingPT(const char * phrasetable_path, const char * target_path,
|
||||
|
||||
BinaryFileWriter binfile(basepath); //Init the binary file writer.
|
||||
|
||||
std::priority_queue<CacheItem*, std::vector<CacheItem*>, CacheItemOrderer> cache;
|
||||
float totalSourceCount = 0;
|
||||
|
||||
line_text prev_line; //Check if the source phrase of the previous line is the same
|
||||
|
||||
//Keep track of the size of each group of target phrases
|
||||
@ -114,34 +110,15 @@ void createProbingPT(const char * phrasetable_path, const char * target_path,
|
||||
entrystartidx = binfile.dist_from_start + binfile.extra_counter; //Designate start idx for new entry
|
||||
|
||||
//Encode a line and write it to disk.
|
||||
std::vector<unsigned char> encoded_line = huffmanEncoder.full_encode_line(line, log_prob);
|
||||
std::vector<unsigned char> encoded_line = huffmanEncoder.full_encode_line(line);
|
||||
binfile.write(&encoded_line);
|
||||
|
||||
// update cache
|
||||
if (max_cache_size) {
|
||||
std::string countStr = line.counts.as_string();
|
||||
countStr = Moses::Trim(countStr);
|
||||
if (!countStr.empty()) {
|
||||
std::vector<float> toks = Moses::Tokenize<float>(countStr);
|
||||
|
||||
if (toks.size() >= 2) {
|
||||
totalSourceCount += toks[1];
|
||||
CacheItem *item = new CacheItem(Moses::Trim(line.source_phrase.as_string()), toks[1]);
|
||||
cache.push(item);
|
||||
|
||||
if (max_cache_size > 0 && cache.size() > max_cache_size) {
|
||||
cache.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//Set prevLine
|
||||
prev_line = line;
|
||||
|
||||
} else {
|
||||
//If we still have the same line, just append to it:
|
||||
std::vector<unsigned char> encoded_line = huffmanEncoder.full_encode_line(line, log_prob);
|
||||
std::vector<unsigned char> encoded_line = huffmanEncoder.full_encode_line(line);
|
||||
binfile.write(&encoded_line);
|
||||
}
|
||||
|
||||
@ -167,11 +144,9 @@ void createProbingPT(const char * phrasetable_path, const char * target_path,
|
||||
}
|
||||
}
|
||||
|
||||
serialize_table(mem, size, (basepath + "/probing_hash.dat"));
|
||||
serialize_table(mem, size, (basepath + "/probing_hash.dat").c_str());
|
||||
|
||||
serialize_map(&source_vocabids, (basepath + "/source_vocabids"));
|
||||
|
||||
serialize_cache(cache, (basepath + "/cache"), totalSourceCount);
|
||||
serialize_map(&source_vocabids, (basepath + "/source_vocabids").c_str());
|
||||
|
||||
delete[] mem;
|
||||
|
||||
@ -181,34 +156,6 @@ void createProbingPT(const char * phrasetable_path, const char * target_path,
|
||||
configfile << API_VERSION << '\n';
|
||||
configfile << uniq_entries << '\n';
|
||||
configfile << num_scores << '\n';
|
||||
configfile << num_lex_scores << '\n';
|
||||
configfile << log_prob << '\n';
|
||||
configfile << is_reordering << '\n';
|
||||
configfile.close();
|
||||
}
|
||||
|
||||
void serialize_cache(std::priority_queue<CacheItem*, std::vector<CacheItem*>, CacheItemOrderer> &cache,
|
||||
const std::string &path,
|
||||
float totalSourceCount)
|
||||
{
|
||||
std::vector<const CacheItem*> vec(cache.size());
|
||||
|
||||
size_t ind = cache.size() - 1;
|
||||
while (!cache.empty()) {
|
||||
const CacheItem *item = cache.top();
|
||||
vec[ind] = item;
|
||||
cache.pop();
|
||||
--ind;
|
||||
}
|
||||
|
||||
std::ofstream os (path.c_str());
|
||||
|
||||
os << totalSourceCount << std::endl;
|
||||
for (size_t i = 0; i < vec.size(); ++i) {
|
||||
const CacheItem *item = vec[i];
|
||||
os << item->count << "\t" << item->source << std::endl;
|
||||
delete item;
|
||||
}
|
||||
|
||||
os.close();
|
||||
}
|
||||
|
||||
|
@ -3,24 +3,19 @@
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <queue>
|
||||
#include <sys/stat.h> //mkdir
|
||||
|
||||
#include "hash.hh" //Includes line_splitter
|
||||
#include "probing_hash_utils.hh"
|
||||
#include "huffmanish.hh"
|
||||
#include <sys/stat.h> //mkdir
|
||||
|
||||
#include "util/file_piece.hh"
|
||||
#include "util/file.hh"
|
||||
#include "vocabid.hh"
|
||||
#define API_VERSION 3
|
||||
|
||||
void createProbingPT(const char * phrasetable_path,
|
||||
const char * target_path,
|
||||
int num_scores,
|
||||
int num_lex_scores,
|
||||
bool log_prob,
|
||||
int max_cache_size);
|
||||
void createProbingPT(const char * phrasetable_path, const char * target_path,
|
||||
const char * num_scores, const char * is_reordering);
|
||||
|
||||
class BinaryFileWriter
|
||||
{
|
||||
@ -39,31 +34,3 @@ public:
|
||||
void flush (); //Flush to disk
|
||||
|
||||
};
|
||||
|
||||
class CacheItem
|
||||
{
|
||||
public:
|
||||
std::string source;
|
||||
float count;
|
||||
CacheItem(const std::string &source, float count)
|
||||
:source(source)
|
||||
,count(count)
|
||||
{}
|
||||
|
||||
bool operator<(const CacheItem &other) const
|
||||
{
|
||||
return count > other.count;
|
||||
}
|
||||
};
|
||||
|
||||
class CacheItemOrderer
|
||||
{
|
||||
public:
|
||||
bool operator()(const CacheItem* a, const CacheItem* b) const {
|
||||
return (*a) < (*b);
|
||||
}
|
||||
};
|
||||
|
||||
void serialize_cache(std::priority_queue<CacheItem*, std::vector<CacheItem*>, CacheItemOrderer> &cache,
|
||||
const std::string &path,
|
||||
float totalSourceCount);
|
||||
|
@ -11,9 +11,9 @@ void add_to_map(std::map<uint64_t, std::string> *karta, StringPiece textin)
|
||||
}
|
||||
}
|
||||
|
||||
void serialize_map(std::map<uint64_t, std::string> *karta, const std::string &filename)
|
||||
void serialize_map(std::map<uint64_t, std::string> *karta, const char* filename)
|
||||
{
|
||||
std::ofstream os (filename.c_str(), std::ios::binary);
|
||||
std::ofstream os (filename, std::ios::binary);
|
||||
boost::archive::text_oarchive oarch(os);
|
||||
|
||||
oarch << *karta; //Serialise map
|
||||
|
@ -15,6 +15,6 @@
|
||||
|
||||
void add_to_map(std::map<uint64_t, std::string> *karta, StringPiece textin);
|
||||
|
||||
void serialize_map(std::map<uint64_t, std::string> *karta, const std::string &filename);
|
||||
void serialize_map(std::map<uint64_t, std::string> *karta, const char* filename);
|
||||
|
||||
void read_map(std::map<uint64_t, std::string> *karta, const char* filename);
|
||||
|
Loading…
Reference in New Issue
Block a user