mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
create sentencepiece within Marian
This commit is contained in:
parent
1474edfbdb
commit
c0ba4d8307
@ -64,7 +64,7 @@ if(USE_SENTENCEPIECE)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SENTENCEPIECE")
|
||||
LIST(APPEND CUDA_NVCC_FLAGS -DUSE_SENTENCEPIECE; )
|
||||
set(EXT_LIBS ${EXT_LIBS} sentencepiece)
|
||||
set(EXT_LIBS ${EXT_LIBS} sentencepiece sentencepiece_train)
|
||||
endif()
|
||||
|
||||
|
||||
|
@ -40,7 +40,7 @@ namespace io {
|
||||
|
||||
class TemporaryFile {
|
||||
private:
|
||||
int fd_;
|
||||
int fd_{-1};
|
||||
bool unlink_;
|
||||
std::string name_;
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#ifdef USE_SENTENCEPIECE
|
||||
#include "sentencepiece/src/sentencepiece_processor.h"
|
||||
#include "sentencepiece/src/sentencepiece_trainer.h"
|
||||
#endif
|
||||
|
||||
#include "common/options.h"
|
||||
@ -9,6 +10,8 @@
|
||||
#include "common/filesystem.h"
|
||||
#include "common/regex.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace marian {
|
||||
|
||||
#ifdef USE_SENTENCEPIECE
|
||||
@ -60,8 +63,46 @@ public:
|
||||
|
||||
void create(const std::string& vocabPath,
|
||||
const std::unordered_map<std::string, size_t>& counter,
|
||||
size_t maxSize) override {
|
||||
ABORT("[data] Creating SentencePieceVocab not supported");
|
||||
size_t /*maxSize*/) override {
|
||||
|
||||
// Create temporary tsv file with vocabulary counts
|
||||
io::TemporaryFile temp(options_->get<std::string>("tempdir"), false);
|
||||
std::string fileName = temp.getFileName();
|
||||
|
||||
LOG(info, "[data] Creating tsv in temporary file {}", fileName);
|
||||
|
||||
{
|
||||
io::OutputFileStream out(temp);
|
||||
for(const auto& it : counter)
|
||||
out << it.first << "\t" << it.second << std::endl;
|
||||
}
|
||||
|
||||
// @TODO: expose parameters
|
||||
// @TODO: compute joint vocabs
|
||||
|
||||
std::stringstream command;
|
||||
command
|
||||
<< " --input_format=tsv"
|
||||
<< " --bos_id=-1 --eos_id=0 --unk_id=1"
|
||||
<< " --input=" << fileName
|
||||
<< " --model_prefix=" << vocabPath
|
||||
<< " --vocab_size=1000";
|
||||
|
||||
const auto status = sentencepiece::SentencePieceTrainer::Train(command.str());
|
||||
|
||||
ABORT_IF(!status.ok(),
|
||||
"SentencePieceVocab error: {}",
|
||||
status.ToString());
|
||||
|
||||
LOG(info, "[data] Removing {}", vocabPath + ".vocab");
|
||||
ABORT_IF(remove((vocabPath + ".vocab").c_str()) != 0,
|
||||
"Could not remove {}",
|
||||
vocabPath + ".vocab");
|
||||
|
||||
LOG(info, "[data] Renaming {} to {}", vocabPath + ".model", vocabPath);
|
||||
ABORT_IF(rename((vocabPath + ".model").c_str(), vocabPath.c_str()) != 0,
|
||||
"Could not rename {} to {}",
|
||||
vocabPath + ".model", vocabPath);
|
||||
}
|
||||
|
||||
void createFake() {
|
||||
@ -107,15 +148,15 @@ public:
|
||||
LOG(info, "[data] Loading SentencePieceVocab from file {}", vocabPath);
|
||||
|
||||
ABORT_IF(!filesystem::exists(vocabPath),
|
||||
"SentencePieceVocab file {} does not exits",
|
||||
vocabPath);
|
||||
"SentencePieceVocab file {} does not exits",
|
||||
vocabPath);
|
||||
|
||||
spm_.reset(new sentencepiece::SentencePieceProcessor());
|
||||
const auto status = spm_->Load(vocabPath);
|
||||
|
||||
ABORT_IF(!status.ok(),
|
||||
"SentencePieceVocab error: {}",
|
||||
status.ToString());
|
||||
"SentencePieceVocab error: {}",
|
||||
status.ToString());
|
||||
|
||||
return spm_->GetPieceSize();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user