Implement shuffling correctly

This commit is contained in:
Barry Haddow 2014-08-04 20:51:45 +01:00
parent e2e07940ae
commit 05455eb0c1
2 changed files with 20 additions and 21 deletions

View File

@ -173,21 +173,14 @@ HypergraphHopeFearDecoder::HypergraphHopeFearDecoder
static const string kWeights = "weights";
fs::directory_iterator dend;
size_t fileCount = 0;
vector<fs::path> hypergraphFiles;
for (fs::directory_iterator di(hypergraphDir); di != dend; ++di) {
if (di->path().filename() == kWeights) continue;
hypergraphFiles.push_back(di->path());
}
if (!no_shuffle) {
random_shuffle(hypergraphFiles.begin(), hypergraphFiles.end());
}
cerr << "Reading " << hypergraphFiles.size() << " hypergraphs" << endl;
for (vector<fs::path>::const_iterator di = hypergraphFiles.begin(); di != hypergraphFiles.end(); ++di) {
cerr << "Reading hypergraphs" << endl;
for (fs::directory_iterator di(hypergraphDir); di != dend; ++di) {
const fs::path& hgpath = di->path();
if (hgpath.filename() == kWeights) continue;
Graph graph(vocab_);
size_t id = boost::lexical_cast<size_t>(di->stem().string());
util::scoped_fd fd(util::OpenReadOrThrow(di->string().c_str()));
size_t id = boost::lexical_cast<size_t>(hgpath.stem().string());
util::scoped_fd fd(util::OpenReadOrThrow(hgpath.string().c_str()));
//util::FilePiece file(di->path().string().c_str());
util::FilePiece file(fd.release());
ReadGraph(file,graph);
@ -205,19 +198,24 @@ HypergraphHopeFearDecoder::HypergraphHopeFearDecoder
}
cerr << endl << "Done" << endl;
sentenceIds_.resize(graphs_.size());
for (size_t i = 0; i < graphs_.size(); ++i) sentenceIds_[i] = i;
if (!no_shuffle) {
random_shuffle(sentenceIds_.begin(), sentenceIds_.end());
}
}
void HypergraphHopeFearDecoder::reset() {
graphIter_ = graphs_.begin();
sentenceIdIter_ = sentenceIds_.begin();
}
void HypergraphHopeFearDecoder::next() {
++graphIter_;
sentenceIdIter_++;
}
bool HypergraphHopeFearDecoder::finished() {
return graphIter_ == graphs_.end();
return sentenceIdIter_ == sentenceIds_.end();
}
void HypergraphHopeFearDecoder::HopeFear(
@ -225,10 +223,10 @@ void HypergraphHopeFearDecoder::HopeFear(
const MiraWeightVector& wv,
HopeFearData* hopeFear
) {
size_t sentenceId = graphIter_->first;
size_t sentenceId = *sentenceIdIter_;
SparseVector weights;
wv.ToSparse(&weights);
const Graph& graph = *(graphIter_->second);
const Graph& graph = *(graphs_[sentenceId]);
ValType hope_scale = 1.0;
HgHypothesis hopeHypo, fearHypo, modelHypo;
@ -319,11 +317,11 @@ void HypergraphHopeFearDecoder::HopeFear(
void HypergraphHopeFearDecoder::MaxModel(const AvgWeightVector& wv, vector<ValType>* stats) {
assert(!finished());
HgHypothesis bestHypo;
size_t sentenceId = graphIter_->first;
size_t sentenceId = *sentenceIdIter_;
SparseVector weights;
wv.ToSparse(&weights);
vector<ValType> bg(kBleuNgramOrder*2+1);
Viterbi(*(graphIter_->second), weights, 0, references_, sentenceId, bg, &bestHypo);
Viterbi(*(graphs_[sentenceId]), weights, 0, references_, sentenceId, bg, &bestHypo);
stats->resize(bestHypo.bleuStats.size());
/*
for (size_t i = 0; i < bestHypo.text.size(); ++i) {

View File

@ -140,7 +140,8 @@ private:
//maps sentence Id to graph ptr
typedef std::map<size_t, boost::shared_ptr<Graph> > GraphColl;
GraphColl graphs_;
GraphColl::const_iterator graphIter_;
std::vector<size_t> sentenceIds_;
std::vector<size_t>::const_iterator sentenceIdIter_;
ReferenceSet references_;
Vocab vocab_;
};