added factored_vocab stubs

This commit is contained in:
Frank Seide 2019-02-05 11:45:55 -08:00
parent de69efea79
commit b7d245945f
5 changed files with 213 additions and 0 deletions

View File

@ -21,6 +21,7 @@ add_library(marian STATIC
data/vocab.cpp
data/default_vocab.cpp
data/sentencepiece_vocab.cpp
data/factored_vocab.cpp
data/corpus_base.cpp
data/corpus.cpp
data/corpus_sqlite.cpp

131
src/data/factored_vocab.cpp Executable file
View File

@ -0,0 +1,131 @@
#if 0
#include "data/vocab.h"
#include "data/vocab_base.h"
namespace marian {
Word Word::NONE = Word();
Word Word::ZERO = Word(0);
Word Word::DEFAULT_EOS_ID = Word(0);
Word Word::DEFAULT_UNK_ID = Word(1);
Ptr<IVocab> createDefaultVocab();
Ptr<IVocab> createClassVocab();
Ptr<IVocab> createSentencePieceVocab(const std::string& /*vocabPath*/, Ptr<Options>, size_t /*batchIndex*/);
// @TODO: make each vocab peek on type
Ptr<IVocab> createVocab(const std::string& vocabPath, Ptr<Options> options, size_t batchIndex) {
auto vocab = createSentencePieceVocab(vocabPath, options, batchIndex);
if(vocab) {
return vocab; // this is defined which means that a sentencepiece vocabulary could be created, so return it
} else {
// check type of input, if not given, assume "sequence"
auto inputTypes = options->get<std::vector<std::string>>("input-types", {});
std::string inputType = inputTypes.size() > batchIndex ? inputTypes[batchIndex] : "sequence";
return inputType == "class" ? createClassVocab() : createDefaultVocab();
}
}
size_t Vocab::loadOrCreate(const std::string& vocabPath,
const std::vector<std::string>& trainPaths,
size_t maxSize) {
size_t size = 0;
if(vocabPath.empty()) {
// No vocabulary path was given, attempt to first find a vocabulary
// for trainPaths[0] + possible suffixes. If not found attempt to create
// as trainPaths[0] + canonical suffix.
// Only search based on first path, maybe disable this at all?
LOG(info,
"No vocabulary path given; "
"trying to find default vocabulary based on data path {}",
trainPaths[0]);
vImpl_ = createDefaultVocab();
size = vImpl_->findAndLoad(trainPaths[0], maxSize);
if(size == 0) {
auto newVocabPath = trainPaths[0] + vImpl_->canonicalExtension();
LOG(info,
"No vocabulary path given; "
"trying to create vocabulary based on data paths {}",
utils::join(trainPaths, ", "));
create(newVocabPath, trainPaths, maxSize);
size = load(newVocabPath, maxSize);
}
} else {
if(!filesystem::exists(vocabPath)) {
// Vocabulary path was given, but no vocabulary present,
// attempt to create in specified location.
create(vocabPath, trainPaths, maxSize);
}
// Vocabulary path exists, attempting to load
size = load(vocabPath, maxSize);
}
LOG(info, "[data] Setting vocabulary size for input {} to {}", batchIndex_, size);
return size;
}
size_t Vocab::load(const std::string& vocabPath, size_t maxSize) {
if(!vImpl_)
vImpl_ = createVocab(vocabPath, options_, batchIndex_);
return vImpl_->load(vocabPath, (int)maxSize);
}
void Vocab::create(const std::string& vocabPath,
const std::vector<std::string>& trainPaths,
size_t maxSize) {
if(!vImpl_)
vImpl_ = createVocab(vocabPath, options_, batchIndex_);
vImpl_->create(vocabPath, trainPaths, maxSize);
}
void Vocab::create(const std::string& vocabPath,
const std::string& trainPath,
size_t maxSize) {
create(vocabPath, std::vector<std::string>({trainPath}), maxSize);
}
void Vocab::createFake() {
if(!vImpl_)
vImpl_ = createDefaultVocab(); // DefaultVocab is OK here
vImpl_->createFake();
}
// string token to token id
Word Vocab::operator[](const std::string& word) const {
return vImpl_->operator[](word);
}
// token id to string token
const std::string& Vocab::operator[](Word id) const {
return vImpl_->operator[](id);
}
// line of text to list of token ids, can perform tokenization
Words Vocab::encode(const std::string& line,
bool addEOS,
bool inference) const {
return vImpl_->encode(line, addEOS, inference);
}
// list of token ids to single line, can perform detokenization
std::string Vocab::decode(const Words& sentence,
bool ignoreEOS) const {
return vImpl_->decode(sentence, ignoreEOS);
}
// number of vocabulary items
size_t Vocab::size() const { return vImpl_->size(); }
// number of vocabulary items
std::string Vocab::type() const { return vImpl_->type(); }
// return EOS symbol id
Word Vocab::getEosId() const { return vImpl_->getEosId(); }
// return UNK symbol id
Word Vocab::getUnkId() const { return vImpl_->getUnkId(); }
} // namespace marian
#endif

73
src/data/factored_vocab.h Executable file
View File

@ -0,0 +1,73 @@
#pragma once
#include "common/definitions.h"
#include "data/types.h"
#include "common/options.h"
#include "common/file_stream.h"
namespace marian {
class IVocab;
// Wrapper around vocabulary types. Can choose underlying
// vocabulary implementation (vImpl_) based on speficied path
// and suffix.
// Vocabulary implementations can currently be:
// * DefaultVocabulary for YAML (*.yml and *.yaml) and TXT (any other non-specific ending)
// * SentencePiece with suffix *.spm (works, but has to be created outside Marian)
class Vocab {
private:
Ptr<IVocab> vImpl_;
Ptr<Options> options_;
size_t batchIndex_;
public:
Vocab(Ptr<Options> options, size_t batchIndex)
: options_(options), batchIndex_(batchIndex) {}
size_t loadOrCreate(const std::string& vocabPath,
const std::vector<std::string>& trainPaths,
size_t maxSize = 0);
size_t load(const std::string& vocabPath, size_t maxSize = 0);
void create(const std::string& vocabPath,
const std::vector<std::string>& trainPaths,
size_t maxSize);
void create(const std::string& vocabPath,
const std::string& trainPath,
size_t maxSize);
// string token to token id
Word operator[](const std::string& word) const;
// token index to string token
const std::string& operator[](Word word) const;
// line of text to list of token ids, can perform tokenization
Words encode(const std::string& line,
bool addEOS = true,
bool inference = false) const;
// list of token ids to single line, can perform detokenization
std::string decode(const Words& sentence,
bool ignoreEOS = true) const;
// number of vocabulary items
size_t size() const;
// number of vocabulary items
std::string type() const;
// return EOS symbol id
Word getEosId() const;
// return UNK symbol id
Word getUnkId() const;
// create fake vocabulary for collecting batch statistics
void createFake();
};
} // namespace marian

View File

@ -556,6 +556,7 @@
<ClCompile Include="..\src\common\version.cpp" />
<ClCompile Include="..\src\data\alignment.cpp" />
<ClCompile Include="..\src\data\default_vocab.cpp" />
<ClCompile Include="..\src\data\factored_vocab.cpp" />
<ClCompile Include="..\src\data\sentencepiece_vocab.cpp" />
<ClCompile Include="..\src\data\vocab.cpp" />
<ClCompile Include="..\src\data\corpus_base.cpp" />
@ -701,6 +702,7 @@
<ClInclude Include="..\src\common\timer.h" />
<ClInclude Include="..\src\common\types.h" />
<ClInclude Include="..\src\common\version.h" />
<ClInclude Include="..\src\data\factored_vocab.h" />
<ClInclude Include="..\src\data\vocab_base.h" />
<ClInclude Include="..\src\examples\mnist\dataset.h" />
<ClInclude Include="..\src\examples\mnist\model.h" />

View File

@ -484,6 +484,9 @@
<ClCompile Include="..\src\layers\generic.cpp">
<Filter>layers</Filter>
</ClCompile>
<ClCompile Include="..\src\data\factored_vocab.cpp">
<Filter>data</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ClInclude Include="..\src\marian.h" />
@ -1531,6 +1534,9 @@
<ClInclude Include="..\src\data\vocab_base.h">
<Filter>data</Filter>
</ClInclude>
<ClInclude Include="..\src\data\factored_vocab.h">
<Filter>data</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="3rd_party">