Merge ../mosesdecoder into hieu

This commit is contained in:
Hieu Hoang 2014-06-05 17:18:26 +01:00
commit ce2a69ba25
6 changed files with 72 additions and 35 deletions

View File

@ -13,14 +13,16 @@
#ifndef LM_BHIKSHA_H
#define LM_BHIKSHA_H
#include <stdint.h>
#include <assert.h>
#include "lm/model_type.hh"
#include "lm/trie.hh"
#include "util/bit_packing.hh"
#include "util/sorted_uniform.hh"
#include <algorithm>
#include <stdint.h>
#include <assert.h>
namespace lm {
namespace ngram {
struct Config;
@ -73,15 +75,24 @@ class ArrayBhiksha {
ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config);
// Fills `out` with the begin/end bounds for entry `index`. For each bound, the
// high part is the position of an element in the offset array
// [offset_begin_, offset_end_), shifted left by next_inline_.bits, and the low
// part is whatever util::ReadInt57 returns for the record at bit_offset
// (begin) or bit_offset + total_bits (end).
// NOTE(review): this listing contains two declarations of begin_it and two
// end_it loops. The util::BinaryBelow lookup and the `end_it = begin_it` loop
// look like the superseded lines, and the std::upper_bound lookup with the
// `end_it = begin_it + 1` loop like the current ones -- confirm against the
// actual header before relying on either.
void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const {
const uint64_t *begin_it = util::BinaryBelow(util::IdentityAccessor<uint64_t>(), offset_begin_, offset_end_, index);
// Some assertions are commented out because they are expensive.
// assert(*offset_begin_ == 0);
// std::upper_bound returns the first element that is greater. Want the
// last element that is <= to the index.
const uint64_t *begin_it = std::upper_bound(offset_begin_, offset_end_, index) - 1;
// Since *offset_begin_ == 0, the position should be in range.
// assert(begin_it >= offset_begin_);
const uint64_t *end_it;
// Both loops advance end_it while the offset is <= index + 1; they differ
// only in whether the scan starts at begin_it or at begin_it + 1.
for (end_it = begin_it; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {}
for (end_it = begin_it + 1; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {}
// assert(end_it == std::upper_bound(offset_begin_, offset_end_, index + 1));
// The loop stops one past the last offset that is <= index + 1; step back onto it.
--end_it;
// assert(end_it >= begin_it);
out.begin = ((begin_it - offset_begin_) << next_inline_.bits) |
util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask);
out.end = ((end_it - offset_begin_) << next_inline_.bits) |
util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask);
//assert(out.end >= out.begin);
// If this fails, consider rebuilding your model using KenLM after 1e333d786b748555e8f368d2bbba29a016c98052
assert(out.end >= out.begin);
}
void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) {

View File

@ -99,8 +99,11 @@ template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordInd
}
// Writes next_end through bhiksha_.WriteNext at the computed bit position,
// then calls bhiksha_.FinishedLoading(config).
// NOTE(review): this listing contains two declarations of last_next_write and
// two WriteNext calls. The two offset expressions are algebraically equal
// ((insert_index_ + 1) * total_bits_ - InlineBits() versus
// insert_index_ * total_bits_ + (total_bits_ - InlineBits())); the calls
// differ only in the index argument (insert_index_ + 1 versus insert_index_).
// Presumably the insert_index_ form is the current one -- confirm against the
// actual source file.
template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
uint64_t last_next_write = (insert_index_ + 1) * total_bits_ - bhiksha_.InlineBits();
bhiksha_.WriteNext(base_, last_next_write, insert_index_ + 1, next_end);
// Write at insert_index. . .
uint64_t last_next_write = insert_index_ * total_bits_ +
// at the offset where the next pointers are stored.
(total_bits_ - bhiksha_.InlineBits());
bhiksha_.WriteNext(base_, last_next_write, insert_index_, next_end);
bhiksha_.FinishedLoading(config);
}

View File

@ -378,6 +378,7 @@ FFState *LanguageModelDALM::EvaluateChart(const ChartHypothesis& hypo, int featu
// copy chart state
(*newState) = (*prevState);
hypoSizeAll = hypoSize+prevState->GetHypoSize()-1;
// get hypoScore
hypoScore = UntransformLMScore(prevHypo->GetScoreBreakdown().GetScoresForProducer(this)[0]);
@ -496,6 +497,7 @@ void LanguageModelDALM::EvaluateTerminal(
hypoScore += score;
}else{
float score = m_lm->query(wid, state, prefixFragments[prefixLength]);
if(score > 0){
hypoScore -= score;
newState->SetAsLarge();
@ -538,9 +540,8 @@ void LanguageModelDALM::EvaluateNonTerminal(
state = prevState->GetRightContext();
return;
}
DALM::Gap gap(state);
// score its prefix
for(size_t prefixPos = 0; prefixPos < prevPrefixLength; prefixPos++) {
const DALM::Fragment &f = prevPrefixFragments[prefixPos];
@ -551,7 +552,7 @@ void LanguageModelDALM::EvaluateNonTerminal(
if(!gap.is_extended()){
state = prevState->GetRightContext();
return;
}else if(gap.get_count() <= prefixPos+1){
}else if(state.get_count() <= prefixPos+1){
state = prevState->GetRightContext();
return;
}
@ -559,11 +560,12 @@ void LanguageModelDALM::EvaluateNonTerminal(
DALM::Fragment &fnew = prefixFragments[prefixLength];
float score = m_lm->query(f, state, gap, fnew);
hypoScore += score;
if(!gap.is_extended()){
newState->SetAsLarge();
state = prevState->GetRightContext();
return;
}else if(gap.get_count() <= prefixPos+1){
}else if(state.get_count() <= prefixPos+1){
if(!gap.is_finalized()) prefixLength++;
newState->SetAsLarge();
state = prevState->GetRightContext();
@ -582,7 +584,7 @@ void LanguageModelDALM::EvaluateNonTerminal(
if (prevState->LargeEnough()) {
newState->SetAsLarge();
if(prevPrefixLength < prevState->GetHypoSize()){
hypoScore += state.sum_bows(prevPrefixLength, gap.get_count());
hypoScore += state.sum_bows(prevPrefixLength, state.get_count());
}
// copy language model state
state = prevState->GetRightContext();

View File

@ -149,7 +149,7 @@ bool RuleTableLoaderStandard::Load(FormatType format
, size_t /* tableLimit */
, RuleTableTrie &ruleTable)
{
PrintUserTime(string("Start loading text SCFG phrase table. ") + (format==MosesFormat?"Moses ":"Hiero ") + " format");
PrintUserTime(string("Start loading text phrase table. ") + (format==MosesFormat?"Moses ":"Hiero ") + " format");
const StaticData &staticData = StaticData::Instance();
const std::string& factorDelimiter = staticData.GetFactorDelimiter();

View File

@ -0,0 +1,42 @@
#!/usr/bin/perl -w

# Replaces every word of a whitespace-tokenized text file with its cluster
# label, writing one output line per input line.
#
# Arguments: cluster_file in out [tmp]
#   cluster_file  table read by read_cluster_from_mkcls (word, then cluster)
#   in            input text, one sentence per line
#   out           output text; words missing from the table become "<unk>"
#   tmp           accepted for compatibility with existing callers; unused here

use strict;

my ($cluster_file,$in,$out,$tmp) = @ARGV;
die("syntax: $0 cluster_file in out [tmp]\n")
  unless defined($cluster_file) && defined($in) && defined($out);

my $CLUSTER = &read_cluster_from_mkcls($cluster_file);

# Fail loudly on a bad path instead of silently producing an empty output file.
open(my $in_fh, '<', $in) or die("ERROR: cannot open '$in' for reading: $!\n");
open(my $out_fh, '>', $out) or die("ERROR: cannot open '$out' for writing: $!\n");

while(my $line = <$in_fh>) {
  # split(' ', ...) skips leading whitespace and ignores the trailing newline,
  # so no chop/chomp is needed and a final line without a newline keeps its
  # last character.
  my @labels = map { defined($$CLUSTER{$_}) ? $$CLUSTER{$_} : "<unk>" } split(' ', $line);
  print $out_fh join(" ", @labels), "\n";
}

# A failed close on the output handle means buffered data was not written.
close($out_fh) or die("ERROR: cannot finish writing '$out': $!\n");
close($in_fh);
# Reads a word-to-cluster table and returns it as a hash reference
# (word => cluster label).
#
# Each line of $file is split on whitespace: the first field is the word, the
# second its cluster label. Lines with fewer than two fields (for example
# blank lines) are skipped. If a word occurs more than once, the last
# occurrence wins. Dies if no file name is given or the file cannot be opened.
sub read_cluster_from_mkcls {
  my ($file) = @_;
  die("ERROR: no cluster file given\n") unless defined($file);
  my %CLUSTER;
  open(my $cluster_fh, '<', $file) or die("ERROR: cannot open cluster file '$file': $!\n");
  while(my $line = <$cluster_fh>) {
    # split(' ', ...) ignores the trailing newline, so the last label stays
    # intact even when the file does not end with a newline.
    my ($word,$cluster) = split(' ', $line);
    next unless defined($cluster);
    $CLUSTER{$word} = $cluster;
  }
  close($cluster_fh);
  return \%CLUSTER;
}
# Placeholder: the body is empty, and nothing in the visible part of this
# script calls it.
sub add_cluster_to_string {
}

View File

@ -101,27 +101,6 @@ template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(co
return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
}
// Returns an iterator to the last element of the sorted range [begin, end)
// whose accessor value is <= key; equivalent to std::upper_bound(...) - 1.
//
// May return begin - 1: this happens when the range is empty or when every
// element is greater than key. The caller must not dereference that result.
//
// Implemented as a plain upper-bound binary search, so the cost is
// O(log(end - begin)) even when the range holds long runs of equal keys.
template <class Iterator, class Accessor> Iterator BinaryBelow(
    const Accessor &accessor,
    Iterator begin,
    Iterator end,
    const typename Accessor::Key key) {
  // Invariant: every element before begin is <= key and every element from
  // end onwards is > key.
  while (end > begin) {
    Iterator pivot(begin + (end - begin) / 2);
    typename Accessor::Key mid(accessor(pivot));
    if (mid > key) {
      end = pivot;
    } else {
      begin = pivot + 1;
    }
  }
  // begin is now the first element greater than key; step back onto the
  // last element that is <= key.
  return begin - 1;
}
} // namespace util
#endif // UTIL_SORTED_UNIFORM_H