mirror of
synced 2024-12-29 15:04:05 +03:00
273 lines
7.6 KiB
273 lines
7.6 KiB
// $Id$
/***********************************************************************
 Moses - factored phrase-based, hierarchical and syntactic language decoder
 Copyright (C) 2009 Hieu Hoang

 This library is free software; you can redistribute it and/or
 modify it under the terms of the GNU Lesser General Public
 License as published by the Free Software Foundation; either
 version 2.1 of the License, or (at your option) any later version.

 This library is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 Lesser General Public License for more details.

 You should have received a copy of the GNU Lesser General Public
 License along with this library; if not, write to the Free Software
 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 ***********************************************************************/
#include <cstdlib>
#include <cstring>
#include "util/check.hh"
#include "PhraseNode.h"
#include "OnDiskWrapper.h"
#include "TargetPhraseCollection.h"
#include "SourcePhrase.h"
#include "../moses/src/Util.h"
using namespace std;
namespace OnDiskPt
size_t PhraseNode::GetNodeSize(size_t numChildren, size_t wordSize, size_t countSize)
size_t ret = sizeof(UINT64) * 2 // num children, value
+ (wordSize + sizeof(UINT64)) * numChildren // word + ptr to next source node
+ sizeof(float) * countSize; // count info
return ret;
// Load a previously saved node from the source file.
//   filePos       - byte offset of the node inside the source file
//   onDiskWrapper - supplies the file stream, the count size and the source word size
// The node is read twice: first only its leading UINT64 (the child count) to
// learn how large the node is, then the whole node into m_memLoad.
PhraseNode::PhraseNode(UINT64 filePos, OnDiskWrapper &onDiskWrapper)
{
  m_filePos = filePos;

  const size_t countSize = onDiskWrapper.GetNumCounts();
  // m_counts[0] is written below, so the vector must hold countSize entries first
  m_counts.resize(countSize);

  std::fstream &file = onDiskWrapper.GetFileSource();

  // position the stream on the node before reading anything
  file.seekg(filePos);
  CHECK(filePos == file.tellg());

  file.read((char*) &m_numChildrenLoad, sizeof(UINT64));

  const size_t memAlloc = GetNodeSize(m_numChildrenLoad, onDiskWrapper.GetSourceWordSize(), countSize);
  m_memLoad = (char*) malloc(memAlloc);
  CHECK(m_memLoad != NULL);

  // go to start of node again; the read above advanced the stream by one UINT64
  file.seekg(filePos);
  CHECK(filePos == file.tellg());

  // read everything into memory
  file.read(m_memLoad, memAlloc);
  CHECK(filePos + memAlloc == file.tellg());

  // get value: second UINT64 of the header
  m_value = ((UINT64*)m_memLoad)[1];

  // get counts: they follow directly after the two UINT64 header fields
  float *memFloat = (float*) (m_memLoad + sizeof(UINT64) * 2);

  CHECK(countSize == 1);
  m_counts[0] = memFloat[0];

  // one past the last byte of the loaded node
  m_memLoadLast = m_memLoad + memAlloc;
}
float PhraseNode::GetCount(size_t ind) const
return m_counts[ind];
// Serialise this node and append it to the source file.
//   onDiskWrapper - supplies the file stream, the count size and the source word size
//   pos           - depth of this node; children are saved with pos + 1
//   tableLimit    - forwarded unchanged to the children's Save()
// Children that are not yet saved are saved first, because this node stores
// each child's file position. On return m_filePos holds the position recorded
// for this node and m_saved is true.
// NOTE(review): only the file position of m_targetPhraseColl is read here, so
// the collection is assumed to have been written already - confirm against the caller.
void PhraseNode::Save(OnDiskWrapper &onDiskWrapper, size_t pos, size_t tableLimit)
{
  // file pos of corresponding target phrases
  m_value = m_targetPhraseColl.GetFilePos();

  const size_t numCounts = onDiskWrapper.GetNumCounts();
  const size_t memAlloc = GetNodeSize(GetSize(), onDiskWrapper.GetSourceWordSize(), numCounts);

  char *mem = (char*) malloc(memAlloc);
  CHECK(mem != NULL);
  //memset(mem, 0xfe, memAlloc);

  size_t memUsed = 0;

  // header: offsets 0 and 8 of a malloc'd buffer, so suitably aligned for UINT64
  UINT64 *header = (UINT64*) mem;
  header[0] = GetSize(); // num of children
  header[1] = m_value;   // file pos of corresponding target phrases
  memUsed += 2 * sizeof(UINT64);

  // count info
  float *memFloat = (float*) (mem + memUsed);
  CHECK(numCounts == 1);
  // if no count was recorded, put in DEFAULT_COUNT to make sure the node is still used. HACK
  memFloat[0] = (m_counts.size() == 0) ? DEFAULT_COUNT : m_counts[0];
  memUsed += sizeof(float) * numCounts;

  // recursively save children
  for (ChildColl::iterator iter = m_children.begin(); iter != m_children.end(); ++iter) {
    const Word &childWord = iter->first;
    PhraseNode &childNode = iter->second;

    // the child must be on disk before its file position can be recorded
    if (!childNode.Saved()) {
      childNode.Save(onDiskWrapper, pos + 1, tableLimit);
    }

    char *currMem = mem + memUsed;
    memUsed += childWord.WriteToMemory(currMem);

    // the slot after the word is generally not 8-byte aligned, so copy the
    // bytes instead of storing through a UINT64 pointer
    const UINT64 childFilePos = childNode.GetFilePos();
    std::memcpy(mem + memUsed, &childFilePos, sizeof(UINT64));
    memUsed += sizeof(UINT64);
  }

  // save this node
  //Moses::DebugMem(mem, memAlloc);
  CHECK(memUsed == memAlloc);

  std::fstream &file = onDiskWrapper.GetFileSource();
  // the position is taken before seeking to the end; the CHECK below only
  // holds when the put position was already at the end of the file
  m_filePos = file.tellp();
  file.seekp(0, std::ios::end);
  file.write(mem, memUsed);

  UINT64 endPos = file.tellp();

  // the buffer has been written out and is no longer needed
  free(mem);

  CHECK(m_filePos + memUsed == endPos);

  m_saved = true;
}
void PhraseNode::AddTargetPhrase(const SourcePhrase &sourcePhrase, TargetPhrase *targetPhrase
, OnDiskWrapper &onDiskWrapper, size_t tableLimit
, const std::vector<float> &counts)
AddTargetPhrase(0, sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, counts);
void PhraseNode::AddTargetPhrase(size_t pos, const SourcePhrase &sourcePhrase
, TargetPhrase *targetPhrase, OnDiskWrapper &onDiskWrapper
, size_t tableLimit, const std::vector<float> &counts)
size_t phraseSize = sourcePhrase.GetSize();
if (pos < phraseSize) {
const Word &word = sourcePhrase.GetWord(pos);
PhraseNode &node = m_children[word];
if (m_currChild != &node) {
// new node
if (m_currChild) {
m_currChild->Save(onDiskWrapper, pos, tableLimit);
m_currChild = &node;
node.AddTargetPhrase(pos + 1, sourcePhrase, targetPhrase, onDiskWrapper, tableLimit, counts);
} else {
// drilled down to the right node
m_counts = counts;
// Binary-search the loaded children for wordSought.
//   wordSought    - word labelling the child being looked up
//   onDiskWrapper - used to decode child entries and to load the child node
// Returns a newly allocated PhraseNode loaded from the child's file position,
// or NULL when no child carries that word. The node is created with new and
// handed to the caller.
// The search relies on the child entries being ordered consistently with
// Word's operator<.
const PhraseNode *PhraseNode::GetChild(const Word &wordSought, OnDiskWrapper &onDiskWrapper) const
{
  // half-open range [lo, hi) in the count's own unsigned type, so a large
  // child count is not narrowed and an empty range needs no negative index
  UINT64 lo = 0;
  UINT64 hi = m_numChildrenLoad;

  while (lo < hi) {
    // midpoint written so that lo + hi cannot overflow
    const UINT64 mid = lo + (hi - lo) / 2;

    Word wordFound;
    UINT64 childFilePos;
    GetChild(wordFound, childFilePos, static_cast<size_t>(mid), onDiskWrapper);

    if (wordSought == wordFound) {
      // stop at the first match so exactly one node is allocated
      return new PhraseNode(childFilePos, onDiskWrapper);
    }

    if (wordSought < wordFound) {
      hi = mid;
    } else {
      lo = mid + 1;
    }
  }

  return NULL;
}
// Decode the child entry at index ind from the loaded node memory.
//   wordFound    - receives the child's word
//   childFilePos - receives the child's file position
//   ind          - zero-based index of the child entry; not range-checked here
// Child entries have a fixed size (source word size plus one UINT64) and start
// after the two UINT64 header fields and the float counts.
void PhraseNode::GetChild(Word &wordFound, UINT64 &childFilePos, size_t ind, OnDiskWrapper &onDiskWrapper) const
{
  const size_t wordSize = onDiskWrapper.GetSourceWordSize();
  const size_t entrySize = wordSize + sizeof(UINT64);
  const size_t numFactors = onDiskWrapper.GetNumSourceFactors();

  // size & file pos of target phrase coll, followed by the count info
  const size_t headerSize = sizeof(UINT64) * 2
                            + sizeof(float) * onDiskWrapper.GetNumCounts();

  char *entryMem = m_memLoad + headerSize + entrySize * ind;

  const size_t bytesRead = ReadChild(wordFound, childFilePos, entryMem, numFactors);
  CHECK(bytesRead == entrySize);
}
// Decode one child entry (a word followed by a UINT64 file position) from mem.
//   wordFound    - receives the word via Word::ReadFromMemory
//   childFilePos - receives the UINT64 stored directly after the word
//   mem          - start of the child entry
//   numFactors   - passed through to Word::ReadFromMemory
// Returns the number of bytes consumed: the word's size plus sizeof(UINT64).
size_t PhraseNode::ReadChild(Word &wordFound, UINT64 &childFilePos, const char *mem, size_t numFactors) const
{
  size_t memRead = wordFound.ReadFromMemory(mem, numFactors);

  // the file position follows the word at an offset that need not be 8-byte
  // aligned; copy the bytes rather than dereferencing a cast pointer, which
  // also keeps mem const
  std::memcpy(&childFilePos, mem + memRead, sizeof(UINT64));
  memRead += sizeof(UINT64);

  return memRead;
}
// Build the target phrase collection belonging to this node.
//   tableLimit    - passed through to TargetPhraseCollection::ReadFromFile
//   onDiskWrapper - passed through to TargetPhraseCollection::ReadFromFile
// Returns a collection allocated with new and handed to the caller. It is
// filled from file position m_value when m_value is greater than zero;
// otherwise it is returned as default-constructed.
const TargetPhraseCollection *PhraseNode::GetTargetPhraseCollection(size_t tableLimit, OnDiskWrapper &onDiskWrapper) const
{
  TargetPhraseCollection *coll = new TargetPhraseCollection();

  const bool hasStoredPhrases = (m_value > 0);
  if (hasStoredPhrases) {
    coll->ReadFromFile(tableLimit, m_value, onDiskWrapper);
  }

  return coll;
}
std::ostream& operator<<(std::ostream &out, const PhraseNode &node)
out << "node (" << node.GetFilePos() << "," << node.GetValue() << "," << node.m_pos << ")";
return out;