Create SentencePiece vocabulary within Marian

This commit is contained in:
Marcin Junczys-Dowmunt 2018-11-23 23:12:55 -08:00
parent 1474edfbdb
commit c0ba4d8307
3 changed files with 49 additions and 8 deletions

View File

@ -64,7 +64,7 @@ if(USE_SENTENCEPIECE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SENTENCEPIECE")
LIST(APPEND CUDA_NVCC_FLAGS -DUSE_SENTENCEPIECE; )
set(EXT_LIBS ${EXT_LIBS} sentencepiece)
set(EXT_LIBS ${EXT_LIBS} sentencepiece sentencepiece_train)
endif()

View File

@ -40,7 +40,7 @@ namespace io {
class TemporaryFile {
private:
int fd_;
int fd_{-1};
bool unlink_;
std::string name_;

View File

@ -2,6 +2,7 @@
#ifdef USE_SENTENCEPIECE
#include "sentencepiece/src/sentencepiece_processor.h"
#include "sentencepiece/src/sentencepiece_trainer.h"
#endif
#include "common/options.h"
@ -9,6 +10,8 @@
#include "common/filesystem.h"
#include "common/regex.h"
#include <sstream>
namespace marian {
#ifdef USE_SENTENCEPIECE
@ -60,8 +63,46 @@ public:
// Creates a SentencePiece vocabulary file at `vocabPath` from word counts.
//
// Steps performed by the body below:
//   1. write every (key, count) pair of `counter` as a tab-separated line into
//      a temporary file located in the "tempdir" option directory;
//   2. call sentencepiece::SentencePieceTrainer::Train() with a flag string
//      that points at that file and uses `vocabPath` as the model prefix;
//   3. remove `vocabPath + ".vocab"` and rename `vocabPath + ".model"` to
//      `vocabPath`.
// Any failing step aborts via ABORT_IF.
//
// @param vocabPath  destination path; also passed as --model_prefix
// @param counter    map from string to count, written out as the tsv input
// @param maxSize    unnamed (commented out) in the second signature line and
//                   never read; the vocabulary size is fixed to 1000 below
//
// NOTE(review): this span carries two variants of the third parameter line
// plus an ABORT stub ahead of the real body. It looks like the removed and
// added lines of a diff shown side by side - confirm which lines are live
// before compiling.
void create(const std::string& vocabPath,
const std::unordered_map<std::string, size_t>& counter,
size_t maxSize) override {
ABORT("[data] Creating SentencePieceVocab not supported");
size_t /*maxSize*/) override {
// Create temporary tsv file with vocabulary counts
// NOTE(review): the second constructor argument is `false`; presumably this is
// the unlink flag of TemporaryFile - confirm against its constructor.
io::TemporaryFile temp(options_->get<std::string>("tempdir"), false);
std::string fileName = temp.getFileName();
LOG(info, "[data] Creating tsv in temporary file {}", fileName);
{
// The stream lives only inside this scope; presumably that ensures it is
// flushed/closed before the trainer reads the file - TODO confirm.
io::OutputFileStream out(temp);
for(const auto& it : counter)
// One "<key>\t<count>" line per entry. NOTE(review): std::endl flushes on
// every line; "\n" would avoid the per-line flush.
out << it.first << "\t" << it.second << std::endl;
}
// @TODO: expose parameters
// @TODO: compute joint vocabs
// Assemble the trainer flags as one string. All values are hard-coded except
// the input file and the model prefix.
// NOTE(review): fileName and vocabPath are inserted unquoted - confirm how the
// trainer handles paths that contain spaces.
std::stringstream command;
command
<< " --input_format=tsv"
<< " --bos_id=-1 --eos_id=0 --unk_id=1"
<< " --input=" << fileName
<< " --model_prefix=" << vocabPath
<< " --vocab_size=1000";
// Run training in-process and abort with the trainer's status text on failure.
const auto status = sentencepiece::SentencePieceTrainer::Train(command.str());
ABORT_IF(!status.ok(),
"SentencePieceVocab error: {}",
status.ToString());
// Delete the "<vocabPath>.vocab" file (presumably emitted by the trainer next
// to the model - confirm); abort if the removal fails.
LOG(info, "[data] Removing {}", vocabPath + ".vocab");
ABORT_IF(remove((vocabPath + ".vocab").c_str()) != 0,
"Could not remove {}",
vocabPath + ".vocab");
// Move "<vocabPath>.model" onto the requested path so the result carries
// exactly the name the caller asked for; abort if the rename fails.
LOG(info, "[data] Renaming {} to {}", vocabPath + ".model", vocabPath);
ABORT_IF(rename((vocabPath + ".model").c_str(), vocabPath.c_str()) != 0,
"Could not rename {} to {}",
vocabPath + ".model", vocabPath);
}
void createFake() {
@ -107,15 +148,15 @@ public:
LOG(info, "[data] Loading SentencePieceVocab from file {}", vocabPath);
ABORT_IF(!filesystem::exists(vocabPath),
"SentencePieceVocab file {} does not exits",
vocabPath);
"SentencePieceVocab file {} does not exits",
vocabPath);
spm_.reset(new sentencepiece::SentencePieceProcessor());
const auto status = spm_->Load(vocabPath);
ABORT_IF(!status.ok(),
"SentencePieceVocab error: {}",
status.ToString());
"SentencePieceVocab error: {}",
status.ToString());
return spm_->GetPieceSize();
}