Updated normalizer

Taku Kudo 2018-06-04 20:32:37 +09:00
parent 4f7af0dfad
commit 4e3bcf1373
20 changed files with 538 additions and 255 deletions

View File

@ -1 +0,0 @@
20 20 # =>

View File

@ -25,7 +25,7 @@ setup(name = 'sentencepiece',
author_email='taku@google.com',
description = 'SentencePiece python wrapper',
long_description = long_description,
version='0.0.9',
version='0.1.0',
url = 'https://github.com/google/sentencepiece',
license = 'Apache',
platforms = 'Unix',

View File

@ -14,7 +14,6 @@
#include "bpe_model_trainer.h"
#include "builder.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "util.h"
@ -47,7 +46,6 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
NormalizerSpec normalizer_spec;
normalizer_spec.set_name("identity");
normalizer_spec.set_add_dummy_prefix(false);
EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();
@ -82,7 +80,6 @@ TEST(BPETrainerTest, EndToEndTest) {
NormalizerSpec normalizer_spec;
normalizer_spec.set_name("nfkc");
EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
constexpr int kVocabSize = 8000;
trainer_spec.set_vocab_size(kVocabSize);

View File

@ -13,6 +13,7 @@
// limitations under the License.!
#include "builder.h"
#include <functional>
#ifdef ENABLE_NFKC_COMPILE
#include <unicode/errorcode.h>
@ -33,10 +34,13 @@
namespace sentencepiece {
namespace normalizer {
namespace {
constexpr int kMaxUnicode = 0x10FFFF;
static constexpr char kDefaultNormalizerName[] = "nfkc";
#ifdef ENABLE_NFKC_COMPILE
// Normalize |input| with ICU's normalizer with |mode|.
// Normalize `input` with ICU's normalizer with `mode`.
Builder::Chars UnicodeNormalize(UNormalizationMode mode,
const Builder::Chars &input) {
const std::string utf8 = string_util::UnicodeTextToUTF8(input);
@ -78,7 +82,7 @@ Builder::Chars ToNFD(const Builder::Chars &input) {
}
// Given an NFKD-normalized string, returns a set of all strings which are
// normalized into the same |nfkd|. |norm2orig| is the normalized to
// normalized into the same `nfkd`. `norm2orig` is the normalized to
// un-normalized character mapping.
std::vector<Builder::Chars> ExpandUnnormalized(
const Builder::Chars &nfkd,
@ -104,8 +108,8 @@ std::vector<Builder::Chars> ExpandUnnormalized(
}
#endif
// Normalizes |src| with |chars_map| and returns normalized Chars.
// |max_len| specifies the maximum length of the key in |chars_map|.
// Normalizes `src` with `chars_map` and returns normalized Chars.
// `max_len` specifies the maximum length of the key in `chars_map`.
Builder::Chars Normalize(const Builder::CharsMap &chars_map,
const Builder::Chars &src, int max_len) {
CHECK_GE(max_len, 1);
@ -146,7 +150,7 @@ util::Status Builder::CompileCharsMap(const CharsMap &chars_map,
CHECK_OR_RETURN(output);
CHECK_OR_RETURN(!chars_map.empty());
LOG(INFO) << "Loading CharsMap of size " << chars_map.size();
LOG(INFO) << "Loading CharsMap of size=" << chars_map.size();
// Aggregates the same target strings to save footprint.
std::map<Chars, int> normalized2pos;
@ -201,7 +205,59 @@ util::Status Builder::CompileCharsMap(const CharsMap &chars_map,
trie.size() * trie.unit_size());
*output = Normalizer::EncodePrecompiledCharsMap(trie_blob, normalized);
LOG(INFO) << "Generated normalizer blob. size= " << output->size();
LOG(INFO) << "Generated normalizer blob. size=" << output->size();
return util::OkStatus();
}
// static
util::Status Builder::DecompileCharsMap(StringPiece blob,
Builder::CharsMap *chars_map) {
CHECK_OR_RETURN(chars_map);
chars_map->clear();
StringPiece trie_blob, normalized;
RETURN_IF_ERROR(
Normalizer::DecodePrecompiledCharsMap(blob, &trie_blob, &normalized));
Darts::DoubleArray trie;
trie.set_array(const_cast<char *>(trie_blob.data()),
trie_blob.size() / trie.unit_size());
std::string key;
std::function<void(size_t, size_t)> traverse;
// Given a Trie node at `node_pos` and the key position at `key_pos`,
// expands children nodes from `node_pos`.
// When leaf nodes are found, stores them into `chars_map`.
traverse = [&traverse, &key, &trie, &normalized, &chars_map](
size_t node_pos, size_t key_pos) -> void {
for (int c = 0; c <= 255; ++c) {
key.push_back(static_cast<char>(c));
size_t copied_node_pos = node_pos;
size_t copied_key_pos = key_pos;
// Note: `copied_(node|key)_pos` are non-const references.
// They store the new positions after node traversal.
const Darts::DoubleArray::result_type result = trie.traverse(
key.data(), copied_node_pos, copied_key_pos, key.size());
if (result >= -1) { // node exists.
if (result >= 0) { // has a value after transition.
const StringPiece value = normalized.data() + result;
Chars key_chars, value_chars;
for (const auto c : string_util::UTF8ToUnicodeText(key))
key_chars.push_back(c);
for (const auto c : string_util::UTF8ToUnicodeText(value))
value_chars.push_back(c);
(*chars_map)[key_chars] = value_chars;
}
// Recursively traverse.
traverse(copied_node_pos, copied_key_pos);
}
key.pop_back();
}
};
traverse(0, 0);
return util::OkStatus();
}
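
The new DecompileCharsMap inverts CompileCharsMap, which the tests below rely on for round-trip checks. A minimal usage sketch (assumes only the Builder API shown above and the in-repo CHECK_OK macro; not part of this commit):

using sentencepiece::normalizer::Builder;

Builder::CharsMap chars_map;
chars_map[{0x0061}] = {0x0041};  // 'a' => 'A'
std::string blob;
CHECK_OK(Builder::CompileCharsMap(chars_map, &blob));
Builder::CharsMap decompiled;
CHECK_OK(Builder::DecompileCharsMap(blob, &decompiled));
// decompiled now equals chars_map: the blob encoding is lossless.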
@ -209,6 +265,13 @@ util::Status Builder::CompileCharsMap(const CharsMap &chars_map,
// static
util::Status Builder::GetPrecompiledCharsMap(const std::string &name,
std::string *output) {
CHECK_OR_RETURN(output);
if (name == "identity") {
output->clear();
return util::OkStatus();
}
std::string result;
for (size_t i = 0; i < kNormalizationRules_size; ++i) {
const auto *blob = &kNormalizationRules_blob[i];
@ -222,33 +285,7 @@ util::Status Builder::GetPrecompiledCharsMap(const std::string &name,
}
// static
util::Status Builder::PopulateNormalizerSpec(NormalizerSpec *normalizer_spec) {
CHECK_OR_RETURN(normalizer_spec);
if (!normalizer_spec->normalization_rule_tsv().empty()) {
CHECK_OR_RETURN(normalizer_spec->precompiled_charsmap().empty())
<< "precompiled_charsmap is already defined.";
const auto chars_map = normalizer::Builder::BuildMapFromFile(
normalizer_spec->normalization_rule_tsv());
RETURN_IF_ERROR(CompileCharsMap(
chars_map, normalizer_spec->mutable_precompiled_charsmap()));
normalizer_spec->set_name("user_defined");
} else {
if (normalizer_spec->name().empty()) {
normalizer_spec->set_name(kDefaultNormalizerName);
}
if (normalizer_spec->precompiled_charsmap().empty()) {
RETURN_IF_ERROR(GetPrecompiledCharsMap(
normalizer_spec->name(),
normalizer_spec->mutable_precompiled_charsmap()));
}
}
return util::OkStatus();
}
// static
Builder::CharsMap Builder::BuildNFKCMap() {
util::Status Builder::BuildNFKCMap(CharsMap *chars_map) {
#ifdef ENABLE_NFKC_COMPILE
LOG(INFO) << "Running BuildNFKCMap";
@ -286,7 +323,7 @@ Builder::CharsMap Builder::BuildNFKCMap() {
if (nfkc == nfkd) {
continue;
}
// Expand all possible sequences which are normalized into the same |nfkd|.
// Expand all possible sequences which are normalized into the same `nfkd`.
for (const auto &nfkd_orig : ExpandUnnormalized(nfkd, norm2orig)) {
if (nfkd_orig != nfkc) {
nfkc_map[nfkd_orig] = nfkc;
@ -294,61 +331,94 @@ Builder::CharsMap Builder::BuildNFKCMap() {
}
}
return RemoveRedundantMap(nfkc_map);
RETURN_IF_ERROR(RemoveRedundantMap(&nfkc_map));
*chars_map = std::move(nfkc_map);
#else
LOG(FATAL) << "NFKC compile is not enabled."
LOG(ERROR) << "NFKC compile is not enabled."
<< " rebuild with ./configure --enable-nfkc-compile";
return {};
#endif
return util::OkStatus();
}
// static
Builder::CharsMap Builder::BuildIdentityMap() {
// Adds one dummy entry since empty rule is not allowed.
const CharsMap result = {{{0x0020}, {0x0020}}};
return result;
}
// static
Builder::CharsMap Builder::BuildMapFromFile(StringPiece filename) {
util::Status Builder::LoadCharsMap(StringPiece filename, CharsMap *chars_map) {
LOG(INFO) << "Loading maping file: " << filename.data();
CHECK_OR_RETURN(chars_map);
io::InputBuffer input(filename);
RETURN_IF_ERROR(input.status());
std::string line;
CharsMap chars_map;
chars_map->clear();
while (input.ReadLine(&line)) {
const auto fields = string_util::SplitPiece(line, "\t");
CHECK_GE(fields.size(), 2);
std::vector<char32> src, trg;
for (const auto &s : string_util::SplitPiece(fields[0], " ")) {
for (auto &s : string_util::SplitPiece(fields[0], " ")) {
s.Consume("U+");
src.push_back(string_util::HexToInt<char32>(s));
}
for (const auto &s : string_util::SplitPiece(fields[1], " ")) {
for (auto &s : string_util::SplitPiece(fields[1], " ")) {
s.Consume("U+");
trg.push_back(string_util::HexToInt<char32>(s));
}
CHECK(!src.empty());
CHECK(!trg.empty());
chars_map[src] = trg;
(*chars_map)[src] = trg;
}
return chars_map;
return util::OkStatus();
}
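
For reference, the TSV lines LoadCharsMap accepts look like the following (hedged illustration, with <tab> standing for a literal tab; the `U+` prefix is optional because it is consumed, and any third field, such as the `#` comment SaveCharsMap emits, is ignored since only the first two fields are read):

U+FF21 <tab> U+0041 <tab> # Ａ => A
0061 0308 <tab> 00E4 <tab> # a + combining diaeresis => ä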
// static
Builder::CharsMap Builder::RemoveRedundantMap(const CharsMap &chars_map) {
CharsMap new_chars_map;
util::Status Builder::SaveCharsMap(StringPiece filename,
const Builder::CharsMap &chars_map) {
io::OutputBuffer output(filename);
RETURN_IF_ERROR(output.status());
for (const auto &c : chars_map) {
std::vector<std::string> src, trg;
string_util::UnicodeText srcu, trgu;
for (char32 v : c.first) {
src.push_back(string_util::IntToHex(v));
srcu.push_back(v);
}
for (char32 v : c.second) {
trg.push_back(string_util::IntToHex(v));
trgu.push_back(v);
}
std::string line = string_util::Join(src, " ") + "\t" +
string_util::Join(trg, " ") + "\t# " +
string_util::UnicodeTextToUTF8(c.first) + " => " +
string_util::UnicodeTextToUTF8(c.second);
line = string_util::StringReplace(line, "\n", " ", true);
line = string_util::StringReplace(line, "\r", " ", true);
output.WriteLine(line);
}
return util::OkStatus();
}
// static
util::Status Builder::RemoveRedundantMap(CharsMap *chars_map) {
CHECK_OR_RETURN(chars_map);
CharsMap new_chars_map;
size_t max_len = 0;
for (const auto &p : chars_map) {
for (const auto &p : *chars_map) {
max_len = std::max(p.first.size(), max_len);
if (p.first.size() == 1) {
new_chars_map.insert(p);
}
}
CHECK_GT(max_len, 0);
CHECK_GT_OR_RETURN(max_len, 0);
// Checks whether the rules with size of |len| can be normalized by
// Checks whether the rules with size of `len` can be normalized by
// the rules with size of [1 .. len - 1].
for (size_t len = 2; len <= max_len; ++len) {
for (const auto &p : chars_map) {
for (const auto &p : *chars_map) {
if (p.first.size() == len &&
p.second != Normalize(new_chars_map, p.first, len - 1)) {
new_chars_map.insert(p);
@ -356,12 +426,14 @@ Builder::CharsMap Builder::RemoveRedundantMap(const CharsMap &chars_map) {
}
}
// Verify all characters in |chars_map| are normalized by |new_chars_map|.
for (const auto &p : chars_map) {
CHECK_EQ(p.second, Normalize(new_chars_map, p.first, max_len));
// Verify all characters in `chars_map` are normalized by `new_chars_map`.
for (const auto &p : *chars_map) {
CHECK_EQ_OR_RETURN(p.second, Normalize(new_chars_map, p.first, max_len));
}
return new_chars_map;
*chars_map = std::move(new_chars_map);
return util::OkStatus();
}
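
For intuition: given the rules "a" => "A", "b" => "B", "ab" => "AB", and "abc" => "CBA", RemoveRedundantMap drops "ab" => "AB" because the length-1 rules already turn "ab" into "AB", but keeps "abc" => "CBA", since the shorter rules would produce "ABC" rather than "CBA". RemoveRedundantMapTest below exercises exactly this pattern.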
} // namespace normalizer
} // namespace sentencepiece

View File

@ -45,14 +45,13 @@ class Builder {
static util::Status CompileCharsMap(const CharsMap &chars_map,
std::string *output);
// Decompiles `blob` into `chars_map`.
static util::Status DecompileCharsMap(StringPiece blob, CharsMap *chars_map);
// Returns a pre-compiled binary index with `name`.
static util::Status GetPrecompiledCharsMap(const std::string &name,
std::string *output);
// Populates necessary fields (precompiled_charsmap) from
// `name` or `normalization_rule_tsv` fields in `normalizer_spec`.
static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec);
// Makes a normalization mapping based on NFKC.
//
// Note that Normalizer/Builder classes do not support
@ -88,16 +87,17 @@ class Builder {
// normalizer is the goal of SentencePiece.
//
// TODO(taku): Make NFC, NFD, and NFKD mapping if necessary.
static CharsMap BuildNFKCMap();
// Returns identity mapping, which does not perform any normalization.
static CharsMap BuildIdentityMap();
static util::Status BuildNFKCMap(CharsMap *chars_map);
// Loads Chars map saved in `filename`.
// Format:
// src_uchar1 src_uchar2 ... <tab> trg_uchar1 trg_uchar2...
// (src|trg)_ucharX must be a hex of UCS4.
static CharsMap BuildMapFromFile(StringPiece filename);
// (src|trg)_ucharX must be a hex of Unicode code point.
static util::Status LoadCharsMap(StringPiece filename, CharsMap *chars_map);
// Saves Chars map to `filename` as TSV.
static util::Status SaveCharsMap(StringPiece filename,
const CharsMap &chars_map);
private:
FRIEND_TEST(BuilderTest, RemoveRedundantMapTest);
@ -105,7 +105,7 @@ class Builder {
// Removes redundant rules from `chars_map`.
// When char_maps have "aa" => "bb" and "a" => "b", the first
// rule is not necessary since the second rule can cover the first rule.
static CharsMap RemoveRedundantMap(const CharsMap &chars_map);
static util::Status RemoveRedundantMap(CharsMap *chars_map);
};
} // namespace normalizer
} // namespace sentencepiece

View File

@ -15,6 +15,7 @@
#include "builder.h"
#include "common.h"
#include "normalizer.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "util.h"
@ -33,12 +34,12 @@ TEST(BuilderTest, RemoveRedundantMapTest) {
chars_map[{0x0061, 0x0062}] = {0x0041, 0x0042};
chars_map[{0x0061, 0x0062, 0x0063}] = {0x0043, 0x0042, 0x0041};
const auto new_chars_map = Builder::RemoveRedundantMap(chars_map);
EXPECT_EQ(3, new_chars_map.size());
EXPECT_EQ(new_chars_map.end(), new_chars_map.find({0x0061, 0x0062}));
EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0061}));
EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0062}));
EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0061, 0x0062, 0x0063}));
EXPECT_OK(Builder::RemoveRedundantMap(&chars_map));
EXPECT_EQ(3, chars_map.size());
EXPECT_EQ(chars_map.end(), chars_map.find({0x0061, 0x0062}));
EXPECT_NE(chars_map.end(), chars_map.find({0x0061}));
EXPECT_NE(chars_map.end(), chars_map.find({0x0062}));
EXPECT_NE(chars_map.end(), chars_map.find({0x0061, 0x0062, 0x0063}));
}
TEST(BuilderTest, GetPrecompiledCharsMapWithInvalidNameTest) {
@ -47,25 +48,19 @@ TEST(BuilderTest, GetPrecompiledCharsMapWithInvalidNameTest) {
EXPECT_NOT_OK(Builder::GetPrecompiledCharsMap("__UNKNOWN__", &output));
}
TEST(BuilderTest, BuildIdentityMapTest) {
const auto m = Builder::BuildIdentityMap();
EXPECT_EQ(1, m.size());
}
TEST(BuilderTest, BuildNFKCMapTest) {
Builder::CharsMap chars_map;
#ifdef ENABLE_NFKC_COMPILE
const auto m = Builder::BuildNFKCMap();
EXPECT_TRUE(!m.empty());
EXPECT_OK(Builder::BuildNFKCMap(&chars_map));
EXPECT_TRUE(!chars_map.empty());
#else
EXPECT_DEATH(Builder::BuildNFKCMap());
// EXPECT_DEATH(Builder::BuildNFKCMap(&chars_map));
#endif
}
TEST(BuilderTest, GetPrecompiledCharsMapTest) {
{
NormalizerSpec spec;
spec.set_name("nfkc");
EXPECT_OK(Builder::PopulateNormalizerSpec(&spec));
const NormalizerSpec spec = SentencePieceTrainer::GetNormalizerSpec("nfkc");
const Normalizer normalizer(spec);
EXPECT_EQ(WS "ABC", normalizer.Normalize(""));
EXPECT_EQ(WS "(株)", normalizer.Normalize(""));
@ -73,9 +68,9 @@ TEST(BuilderTest, GetPrecompiledCharsMapTest) {
}
{
NormalizerSpec spec;
spec.set_name("identity");
EXPECT_OK(Builder::PopulateNormalizerSpec(&spec));
const NormalizerSpec spec =
SentencePieceTrainer::GetNormalizerSpec("identity");
EXPECT_TRUE(spec.precompiled_charsmap().empty());
const Normalizer normalizer(spec);
EXPECT_EQ(WS "", normalizer.Normalize(""));
EXPECT_EQ(WS "", normalizer.Normalize(""));
@ -99,6 +94,11 @@ TEST(BuilderTest, CompileCharsMap) {
NormalizerSpec spec;
EXPECT_OK(
Builder::CompileCharsMap(chars_map, spec.mutable_precompiled_charsmap()));
Builder::CharsMap decompiled_chars_map;
EXPECT_OK(Builder::DecompileCharsMap(spec.precompiled_charsmap(),
&decompiled_chars_map));
EXPECT_EQ(chars_map, decompiled_chars_map);
spec.set_add_dummy_prefix(false);
const Normalizer normalizer(spec);
@ -112,12 +112,30 @@ TEST(BuilderTest, CompileCharsMap) {
EXPECT_EQ("ABCabcD", normalizer.Normalize("abcあいうd"));
}
TEST(BuilderTest, BuildMapFromFileTest) {
const auto cmap = Builder::BuildMapFromFile("../data/nfkc.tsv");
std::string expected, precompiled;
EXPECT_OK(Builder::CompileCharsMap(cmap, &precompiled));
EXPECT_OK(Builder::GetPrecompiledCharsMap("nfkc", &expected));
EXPECT_EQ(expected, precompiled);
TEST(BuilderTest, LoadCharsMapTest) {
Builder::CharsMap chars_map;
EXPECT_OK(Builder::LoadCharsMap("../data/nfkc.tsv", &chars_map));
std::string precompiled, expected;
EXPECT_OK(Builder::CompileCharsMap(chars_map, &precompiled));
// Round-trip.
Builder::CharsMap decompiled_chars_map;
EXPECT_OK(Builder::DecompileCharsMap(precompiled, &decompiled_chars_map));
EXPECT_EQ(chars_map, decompiled_chars_map);
test::ScopedTempFile output_tsv("output.tsv");
EXPECT_OK(Builder::SaveCharsMap(output_tsv.filename(), chars_map));
Builder::CharsMap saved_chars_map;
EXPECT_OK(Builder::LoadCharsMap(output_tsv.filename(), &saved_chars_map));
EXPECT_EQ(chars_map, saved_chars_map);
#ifdef ENABLE_NFKC_COMPILE
Builder::CharsMap nfkc_map;
EXPECT_OK(Builder::BuildNFKCMap(&nfkc_map));
EXPECT_OK(Builder::CompileCharsMap(nfkc_map, &expected));
#endif
}
TEST(BuilderTest, ContainsTooManySharedPrefixTest) {
@ -129,7 +147,7 @@ TEST(BuilderTest, ContainsTooManySharedPrefixTest) {
chars_map[keys] = {'b'};
}
std::string output;
EXPECT_NOT_OK(Builder::CompileCharsMap(chars_map, &output));
EXPECT_FALSE(Builder::CompileCharsMap(chars_map, &output).ok());
}
} // namespace normalizer

View File

@ -13,7 +13,6 @@
// limitations under the License.!
#include "char_model_trainer.h"
#include "builder.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "util.h"
@ -45,7 +44,6 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
NormalizerSpec normalizer_spec;
normalizer_spec.set_name("identity");
EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
Trainer trainer(trainer_spec, normalizer_spec);
trainer.Train();

View File

@ -18,34 +18,17 @@
#include <string>
#include "builder.h"
#include "flags.h"
#include "sentencepiece_processor.h"
#include "stringpiece.h"
#include "util.h"
using sentencepiece::normalizer::Builder;
using sentencepiece::util::Status;
DEFINE_bool(output_precompiled_header, false, "make normalization_rule.h file");
namespace sentencepiece {
namespace {
void WriteTSV(const Builder::CharsMap &cmap, StringPiece filename) {
sentencepiece::io::OutputBuffer output(filename);
for (const auto &c : cmap) {
std::vector<std::string> src, trg;
for (char32 v : c.first) {
src.push_back(string_util::IntToHex(v));
}
for (char32 v : c.second) {
trg.push_back(string_util::IntToHex(v));
}
std::string line = string_util::Join(src, " ") + "\t" +
string_util::Join(trg, " ") + "\t# " +
string_util::UnicodeTextToUTF8(c.first) + " => " +
string_util::UnicodeTextToUTF8(c.second);
line = string_util::StringReplace(line, "\n", " ", true);
line = string_util::StringReplace(line, "\r", " ", true);
output.WriteLine(line);
}
}
std::string ToHexData(StringPiece data) {
const char *begin = data.data();
@ -81,9 +64,9 @@ std::string ToHexData(StringPiece data) {
int main(int argc, char **argv) {
sentencepiece::flags::ParseCommandLineFlags(argc, argv);
const std::vector<std::pair<std::string, std::function<Builder::CharsMap()>>>
kRuleList = {{"nfkc", Builder::BuildNFKCMap},
{"identity", Builder::BuildIdentityMap}};
const std::vector<
std::pair<std::string, std::function<Status(Builder::CharsMap *)>>>
kRuleList = {{"nfkc", Builder::BuildNFKCMap}};
constexpr char kHeader[] =
R"(#ifndef NORMALIZATION_RULE_H_
@ -107,13 +90,18 @@ constexpr BinaryBlob kNormalizationRules_blob[] = {)";
os << kHeader;
for (const auto &p : kRuleList) {
const auto normalized_map = p.second();
Builder::CharsMap normalized_map;
CHECK_OK(p.second(&normalized_map));
// Write Header.
std::string index;
CHECK_OK(Builder::CompileCharsMap(normalized_map, &index));
os << "{ \"" << p.first << "\", " << index.size() << ",\n";
os << sentencepiece::ToHexData(index);
os << " },";
sentencepiece::WriteTSV(normalized_map, p.first + ".tsv");
// Write TSV file.
CHECK_OK(Builder::SaveCharsMap(p.first + ".tsv", normalized_map));
}
os << "};\n";
@ -124,6 +112,7 @@ constexpr BinaryBlob kNormalizationRules_blob[] = {)";
if (FLAGS_output_precompiled_header) {
constexpr char kPrecompiledHeaderFileName[] = "normalization_rule.h";
sentencepiece::io::OutputBuffer output(kPrecompiledHeaderFileName);
CHECK_OK(output.status());
output.Write(os.str());
}

View File

@ -57,14 +57,14 @@ bool ModelInterface::IsUnknown(int id) const {
void ModelInterface::InitializePieces(bool enable_user_defined) {
pieces_.clear();
reserved_id_map_.clear();
unk_id_ = 0;
unk_id_ = -1;
for (int i = 0; i < model_proto_->pieces_size(); ++i) {
const auto &sp = model_proto_->pieces(i);
if (!enable_user_defined &&
sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
status_ = util::StatusBuilder(util::error::INTERNAL)
<< "user defined symbol is not supported.";
<< "User defined symbol is not supported.";
return;
}
@ -78,8 +78,19 @@ void ModelInterface::InitializePieces(bool enable_user_defined) {
return;
}
if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
if (sp.type() == ModelProto::SentencePiece::UNKNOWN) {
if (unk_id_ >= 0) {
status_ = util::StatusBuilder(util::error::INTERNAL)
<< "unk is already defined.";
return;
}
unk_id_ = i;
}
}
if (unk_id_ == -1)
status_ = util::StatusBuilder(util::error::INTERNAL)
<< "unk is not defined.";
}
std::vector<StringPiece> SplitIntoWords(StringPiece text) {

View File

@ -23,36 +23,32 @@ namespace normalizer {
constexpr int Normalizer::kMaxTrieResultsSize;
Normalizer::Normalizer(const NormalizerSpec &spec) : spec_(&spec) {
Normalizer::Normalizer(const NormalizerSpec &spec)
: spec_(&spec), status_(util::OkStatus()) {
StringPiece index = spec.precompiled_charsmap();
if (index.empty()) {
status_ = util::InvalidArgumentError("precompiled_charsmap is empty.");
return;
LOG(INFO) << "precompiled_charsmap is empty. use identity normalization.";
} else {
StringPiece trie_blob, normalized;
status_ = DecodePrecompiledCharsMap(index, &trie_blob, &normalized);
if (!status_.ok()) return;
// Reads the body of double array.
trie_ = port::MakeUnique<Darts::DoubleArray>();
// The second arg of set_array is not the size of blob,
// but the number of double array units.
trie_->set_array(const_cast<char *>(trie_blob.data()),
trie_blob.size() / trie_->unit_size());
normalized_ = normalized.data();
}
StringPiece trie_blob, normalized;
status_ = DecodePrecompiledCharsMap(index, &trie_blob, &normalized);
if (!status_.ok()) return;
// Reads the body of double array.
trie_ = port::MakeUnique<Darts::DoubleArray>();
// The second arg of set_array is not the size of blob,
// but the number of double array units.
trie_->set_array(const_cast<char *>(trie_blob.data()),
trie_blob.size() / trie_->unit_size());
normalized_ = normalized.data();
}
Normalizer::~Normalizer() {}
util::Status Normalizer::Normalize(StringPiece input, std::string *normalized,
std::vector<size_t> *norm_to_orig) const {
if (trie_ == nullptr || normalized_ == nullptr) {
return util::InternalError("Normalizer model is not available.");
}
norm_to_orig->clear();
normalized->clear();
@ -60,6 +56,8 @@ util::Status Normalizer::Normalize(StringPiece input, std::string *normalized,
return util::OkStatus();
}
RETURN_IF_ERROR(status());
int consumed = 0;
// Ignores leading space.
@ -144,7 +142,7 @@ util::Status Normalizer::Normalize(StringPiece input, std::string *normalized,
const StringPiece space = spec_->escape_whitespaces() ? kSpaceSymbol : " ";
while (string_util::EndsWith(*normalized, space)) {
const int length = normalized->size() - space.size();
if (length < 0) return util::InternalError("length < 0");
CHECK_GE_OR_RETURN(length, 0);
consumed = (*norm_to_orig)[length];
normalized->resize(length);
norm_to_orig->resize(length);
@ -153,9 +151,7 @@ util::Status Normalizer::Normalize(StringPiece input, std::string *normalized,
norm_to_orig->push_back(consumed);
if (norm_to_orig->size() != normalized->size() + 1) {
return util::InternalError("norm_to_org and normalized are inconsistent");
}
CHECK_EQ_OR_RETURN(norm_to_orig->size(), normalized->size() + 1);
return util::OkStatus();
}
@ -173,25 +169,27 @@ std::pair<StringPiece, int> Normalizer::NormalizePrefix(
if (input.empty()) return result;
// Allocates trie_results in stack, which makes the encoding speed 36% faster.
// (38k sentences/sec => 60k sentences/sec).
// Builder checks that the result size never exceeds kMaxTrieResultsSize.
// This array consumes 0.5kByte in stack, which is less than
// default stack frames (16kByte).
Darts::DoubleArray::result_pair_type
trie_results[Normalizer::kMaxTrieResultsSize];
const size_t num_nodes =
trie_->commonPrefixSearch(input.data(), trie_results,
Normalizer::kMaxTrieResultsSize, input.size());
// Finds the longest rule.
size_t longest_length = 0;
int longest_value = 0;
for (size_t k = 0; k < num_nodes; ++k) {
if (longest_length == 0 || trie_results[k].length > longest_length) {
longest_length = trie_results[k].length; // length of prefix
longest_value = trie_results[k].value; // pointer to |normalized_|.
if (trie_ != nullptr) {
// Allocates trie_results in stack, which makes the encoding speed 36%
// faster. (38k sentences/sec => 60k sentences/sec). Builder checks that the
// result size never exceeds kMaxTrieResultsSize. This array consumes
// 0.5kByte in stack, which is less than default stack frames (16kByte).
Darts::DoubleArray::result_pair_type
trie_results[Normalizer::kMaxTrieResultsSize];
const size_t num_nodes = trie_->commonPrefixSearch(
input.data(), trie_results, Normalizer::kMaxTrieResultsSize,
input.size());
// Finds the longest rule.
for (size_t k = 0; k < num_nodes; ++k) {
if (longest_length == 0 || trie_results[k].length > longest_length) {
longest_length = trie_results[k].length; // length of prefix
longest_value = trie_results[k].value; // pointer to |normalized_|.
}
}
}
@ -239,7 +237,7 @@ util::Status Normalizer::DecodePrecompiledCharsMap(StringPiece blob,
!string_util::DecodePOD<uint32>(
StringPiece(blob.data(), sizeof(trie_blob_size)), &trie_blob_size) ||
trie_blob_size >= blob.size()) {
return util::InternalError("Trie blob is broken.");
return util::InternalError("Blob for normalization rule is broken.");
}
blob.remove_prefix(sizeof(trie_blob_size));
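
As the decode path above shows, a precompiled blob is a uint32 size header, the double-array (trie) bytes, then the concatenated normalized strings. A hedged sketch of the encode side (it assumes a string_util::EncodePOD helper symmetric to the DecodePOD call above; the actual EncodePrecompiledCharsMap in normalizer.cc may differ):

// Sketch only, not this commit's implementation.
std::string EncodePrecompiledCharsMapSketch(StringPiece trie_blob,
                                            StringPiece normalized) {
  std::string blob = string_util::EncodePOD<uint32>(
      static_cast<uint32>(trie_blob.size()));        // 4-byte size header
  blob.append(trie_blob.data(), trie_blob.size());   // double-array body
  blob.append(normalized.data(), normalized.size()); // normalized strings
  return blob;
}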

View File

@ -14,6 +14,7 @@
#include "normalizer.h"
#include "builder.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "util.h"
@ -23,20 +24,14 @@ namespace {
// Space symbol
#define WS "\xe2\x96\x81"
// Replacement char
#define RC "\xEF\xBF\xBD"
NormalizerSpec MakeDefaultSpec() {
NormalizerSpec normalizer_spec;
normalizer_spec.set_name("nfkc");
EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
return normalizer_spec;
return SentencePieceTrainer::GetNormalizerSpec("nfkc");
}
} // namespace
TEST(NormalizerTest, NormalizeErrorTest) {
NormalizerSpec spec;
Normalizer normalizer(spec);
EXPECT_NOT_OK(normalizer.Normalize("test", nullptr, nullptr));
}
TEST(NormalizerTest, NormalizeTest) {
auto spec = MakeDefaultSpec();
const Normalizer normalizer(spec);
@ -211,6 +206,44 @@ TEST(NormalizeTest, NomalizeWithSpaceContainedRules) {
EXPECT_EQ(" A F G ", normalizer.Normalize("ad"));
EXPECT_EQ(" A F G B", normalizer.Normalize("adb"));
}
// Added several corner cases around spaces.
struct SpacePattern {
bool add_dummy_prefix;
bool remove_extra_whitespaces;
bool escape_whitespaces;
const char *input;
const char *expected;
};
constexpr SpacePattern kSpacePatternData[] = {
{false, false, false, WS, WS}, {false, false, true, WS, WS},
{false, true, false, WS, WS}, {false, true, true, WS, ""},
{true, false, false, WS, " " WS}, {true, false, true, WS, WS WS},
{true, true, false, WS, " " WS}, {true, true, true, WS, ""},
{false, false, false, " ", " "}, {false, false, true, " ", WS},
{false, true, false, " ", ""}, {false, true, true, " ", ""},
{true, false, false, " ", " "}, {true, false, true, " ", WS WS},
{true, true, false, " ", ""}, {true, true, true, " ", ""}};
for (const auto &c : kSpacePatternData) {
spec.set_add_dummy_prefix(c.add_dummy_prefix);
spec.set_remove_extra_whitespaces(c.remove_extra_whitespaces);
spec.set_escape_whitespaces(c.escape_whitespaces);
const Normalizer normalizer(spec);
EXPECT_EQ(c.expected, normalizer.Normalize(c.input));
}
}
TEST(NormalizerTest, NormalizeReplacementChar) {
auto spec = MakeDefaultSpec();
spec.set_add_dummy_prefix(false);
const Normalizer normalizer(spec);
EXPECT_EQ("abc" RC "xy", normalizer.Normalize("abc\x80xy"));
EXPECT_EQ("abc" RC, normalizer.Normalize("abc\xc3"));
EXPECT_EQ("ab" RC RC "xy", normalizer.Normalize("ab\xe3\x81xy"));
EXPECT_EQ("a" RC RC RC "xy", normalizer.Normalize("a\xf3\x81\x81xy"));
EXPECT_EQ("ab" RC RC "xy", normalizer.Normalize("ab\xc0\x82xy"));
}
TEST(NormalizerTest, NormalizeFullTest) {
@ -310,6 +343,12 @@ TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest) {
TEST(NormalizerTest, StatusTest) {
NormalizerSpec spec;
{
const Normalizer normalizer(spec);
EXPECT_OK(normalizer.status()); // fallback to identity.
}
{
spec.set_precompiled_charsmap("x");
const Normalizer normalizer(spec);
EXPECT_FALSE(normalizer.status().ok());
}

View File

@ -19,6 +19,7 @@
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "stringpiece.h"
#include "testharness.h"
#include "util.h"
@ -100,10 +101,7 @@ std::vector<std::string> GetSpVec(const SentencePieceText &spt) {
}
NormalizerSpec MakeDefaultNormalizerSpec() {
NormalizerSpec normalizer_spec;
normalizer_spec.set_name("nfkc");
EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
return normalizer_spec;
return SentencePieceTrainer::GetNormalizerSpec("nfkc");
}
TEST(SentencepieceProcessorTest, StatusTest) {

View File

@ -25,6 +25,9 @@
#include "util.h"
namespace sentencepiece {
namespace {
static constexpr char kDefaultNormalizerName[] = "nfkc";
} // namespace
// static
util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
@ -36,13 +39,21 @@ util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
util::Status SentencePieceTrainer::Train(
const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) {
auto copied_normalizer_spec = normalizer_spec;
RETURN_IF_ERROR(
normalizer::Builder::PopulateNormalizerSpec(&copied_normalizer_spec));
RETURN_IF_ERROR(PopulateNormalizerSpec(&copied_normalizer_spec));
auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec);
return trainer->Train();
}
// static
NormalizerSpec SentencePieceTrainer::GetNormalizerSpec(
const std::string &name) {
NormalizerSpec spec;
spec.set_name(name);
CHECK_OK(normalizer::Builder::GetPrecompiledCharsMap(
spec.name(), spec.mutable_precompiled_charsmap()));
return spec;
}
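
The updated tests use this helper in place of the old Builder::PopulateNormalizerSpec two-step; a typical call looks like this (sketch, mirroring builder_test.cc and normalizer_test.cc in this commit):

const NormalizerSpec spec = SentencePieceTrainer::GetNormalizerSpec("nfkc");
const Normalizer normalizer(spec);
// normalizer.Normalize(...) now applies the precompiled NFKC rules.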
// static
util::Status SentencePieceTrainer::SetProtoField(
const std::string &field_name, const std::string &value,
@ -161,8 +172,36 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
util::Status SentencePieceTrainer::Train(const std::string &args) {
TrainerSpec trainer_spec;
NormalizerSpec normalizer_spec;
CHECK_OK(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));
RETURN_IF_ERROR(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));
return Train(trainer_spec, normalizer_spec);
}
// static
util::Status SentencePieceTrainer::PopulateNormalizerSpec(
NormalizerSpec *normalizer_spec) {
CHECK_OR_RETURN(normalizer_spec);
if (!normalizer_spec->normalization_rule_tsv().empty()) {
CHECK_OR_RETURN(normalizer_spec->precompiled_charsmap().empty())
<< "precompiled_charsmap is already defined.";
normalizer::Builder::CharsMap chars_map;
RETURN_IF_ERROR(normalizer::Builder::LoadCharsMap(
normalizer_spec->normalization_rule_tsv(), &chars_map));
RETURN_IF_ERROR(normalizer::Builder::CompileCharsMap(
chars_map, normalizer_spec->mutable_precompiled_charsmap()));
normalizer_spec->set_name("user_defined");
} else {
if (normalizer_spec->name().empty()) {
normalizer_spec->set_name(kDefaultNormalizerName);
}
if (normalizer_spec->precompiled_charsmap().empty()) {
RETURN_IF_ERROR(normalizer::Builder::GetPrecompiledCharsMap(
normalizer_spec->name(),
normalizer_spec->mutable_precompiled_charsmap()));
}
}
return util::OkStatus();
}
} // namespace sentencepiece
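
Note that PopulateNormalizerSpec, now hosted here instead of in Builder, gives normalization_rule_tsv precedence over name. A hedged sketch of the user-defined path (my_rules.tsv is a hypothetical file in the LoadCharsMap TSV format):

NormalizerSpec spec;
spec.set_normalization_rule_tsv("my_rules.tsv");  // hypothetical path
CHECK_OK(SentencePieceTrainer::PopulateNormalizerSpec(&spec));
// spec.name() is now "user_defined" and precompiled_charsmap holds the
// rules compiled from my_rules.tsv.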

View File

@ -45,6 +45,15 @@ class SentencePieceTrainer {
// '--input=data --model_prefix=m --vocab_size=8192 --model_type=unigram'
static util::Status Train(const std::string &args);
// Handy function to make a normalizer spec from the pre-compiled
// normalization name. Do not use this method in production as it crashes
// when `name` is invalid. Useful for unit testing.
static NormalizerSpec GetNormalizerSpec(const std::string &name);
// Populates necessary fields (precompiled_charsmap) from
// `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`.
static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec);
// Overrides `trainer_spec` and `normalizer_spec` with the
// command-line string in `args`.
static util::Status MergeSpecsFromArgs(const std::string &args,

View File

@ -19,6 +19,7 @@
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
#include "util.h"
DEFINE_string(model, "", "Model file name");
@ -32,24 +33,30 @@ DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
DEFINE_bool(remove_extra_whitespaces, true, "Remove extra whitespaces");
DEFINE_string(output, "", "Output filename");
using sentencepiece::ModelProto;
using sentencepiece::NormalizerSpec;
using sentencepiece::SentencePieceProcessor;
using sentencepiece::SentencePieceTrainer;
using sentencepiece::normalizer::Builder;
using sentencepiece::normalizer::Normalizer;
int main(int argc, char *argv[]) {
std::vector<std::string> rest_args;
sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args);
sentencepiece::NormalizerSpec spec;
NormalizerSpec spec;
if (!FLAGS_model.empty()) {
sentencepiece::SentencePieceProcessor sp;
ModelProto model_proto;
SentencePieceProcessor sp;
CHECK_OK(sp.Load(FLAGS_model));
spec = sp.model_proto().normalizer_spec();
} else if (!FLAGS_normalization_rule_tsv.empty()) {
spec.set_normalization_rule_tsv(FLAGS_normalization_rule_tsv);
CHECK_OK(Builder::PopulateNormalizerSpec(&spec));
CHECK_OK(SentencePieceTrainer::PopulateNormalizerSpec(&spec));
} else if (!FLAGS_normalization_rule_name.empty()) {
spec.set_name(FLAGS_normalization_rule_name);
CHECK_OK(Builder::PopulateNormalizerSpec(&spec));
CHECK_OK(SentencePieceTrainer::PopulateNormalizerSpec(&spec));
} else {
LOG(FATAL) << "Sets --model, normalization_rule_tsv, or "
"normalization_rule_name flag.";
@ -62,7 +69,7 @@ int main(int argc, char *argv[]) {
spec.set_remove_extra_whitespaces(FLAGS_remove_extra_whitespaces);
}
sentencepiece::normalizer::Normalizer normalizer(spec);
const Normalizer normalizer(spec);
sentencepiece::io::OutputBuffer output(FLAGS_output);
CHECK_OK(output.status());

View File

@ -283,29 +283,37 @@ void TrainerInterface::SplitSentencesByWhitespace() {
}
util::Status TrainerInterface::Serialize(ModelProto *model_proto) const {
RETURN_IF_ERROR(status());
// Duplicated sentencepiece is not allowed.
std::unordered_set<std::string> dup;
std::set<std::string> dup;
#define CHECK_PIECE(piece) \
CHECK_OR_RETURN(string_util::IsStructurallyValid(piece)); \
CHECK_OR_RETURN(!piece.empty()); \
CHECK_OR_RETURN(dup.insert(piece).second) << piece << " is already defined";
for (const auto &w : meta_pieces_) {
auto *sp = model_proto->add_pieces();
sp->set_piece(w.first);
sp->set_type(w.second);
sp->set_score(0.0);
CHECK_NE_OR_RETURN(ModelProto::SentencePiece::NORMAL, sp->type());
CHECK_PIECE(sp->piece());
size_t fid = 0;
for (int id = 0; id < trainer_spec_.vocab_size(); ++id) {
const auto it = meta_pieces_.find(id);
if (it != meta_pieces_.end()) {
auto *sp = model_proto->add_pieces();
sp->set_piece(it->second.first);
sp->set_type(it->second.second);
sp->set_score(0.0);
CHECK_EQ_OR_RETURN(model_proto->pieces_size() - 1, it->first);
CHECK_NE_OR_RETURN(ModelProto::SentencePiece::NORMAL, sp->type());
CHECK_PIECE(sp->piece());
} else if (fid < final_pieces_.size()) {
const auto &w = final_pieces_[fid++];
auto *sp = model_proto->add_pieces();
sp->set_piece(w.first);
sp->set_score(w.second);
CHECK_PIECE(sp->piece());
}
}
for (const auto &w : final_pieces_) {
auto *sp = model_proto->add_pieces();
sp->set_piece(w.first);
sp->set_score(w.second);
CHECK_PIECE(sp->piece());
}
CHECK_EQ_OR_RETURN(fid, final_pieces_.size());
*(model_proto->mutable_trainer_spec()) = trainer_spec_;
*(model_proto->mutable_normalizer_spec()) = normalizer_spec_;
@ -314,13 +322,13 @@ util::Status TrainerInterface::Serialize(ModelProto *model_proto) const {
trainer_spec_.model_type() == TrainerSpec::CHAR) {
CHECK_GE_OR_RETURN(trainer_spec_.vocab_size(), model_proto->pieces_size());
CHECK_GE_OR_RETURN(trainer_spec_.vocab_size(),
static_cast<int>(dup.size()));
static_cast<int32>(dup.size()));
model_proto->mutable_trainer_spec()->set_vocab_size(
model_proto->pieces_size());
} else {
CHECK_OR_RETURN(trainer_spec_.vocab_size() == model_proto->pieces_size() &&
trainer_spec_.vocab_size() == static_cast<int>(dup.size()))
<< "Use --hard_vocab_limit=false to make the vocab size `soft limit`.";
CHECK_EQ_OR_RETURN(trainer_spec_.vocab_size(), model_proto->pieces_size());
CHECK_EQ_OR_RETURN(trainer_spec_.vocab_size(),
static_cast<int32>(dup.size()));
}
return util::OkStatus();
@ -361,40 +369,64 @@ util::Status TrainerInterface::Save() const {
util::Status TrainerInterface::InitMetaPieces() {
CHECK_OR_RETURN(meta_pieces_.empty());
std::vector<std::pair<int, std::string>> ids;
if (trainer_spec_.unk_id() >= 0)
ids.emplace_back(trainer_spec_.unk_id(), kUNK);
if (trainer_spec_.bos_id() >= 0)
ids.emplace_back(trainer_spec_.bos_id(), kBOS);
if (trainer_spec_.eos_id() >= 0)
ids.emplace_back(trainer_spec_.eos_id(), kEOS);
if (trainer_spec_.pad_id() >= 0)
ids.emplace_back(trainer_spec_.pad_id(), kPAD);
std::sort(ids.begin(), ids.end());
int prev_id = -1;
bool has_unk = false;
for (const auto &p : ids) {
CHECK_EQ_OR_RETURN(prev_id + 1, p.first)
<< "ID for `" << p.second << "` must be " << prev_id + 1;
prev_id = p.first;
CHECK_EQ_OR_RETURN(static_cast<int>(meta_pieces_.size()), p.first);
if (p.second == kUNK) has_unk = true;
meta_pieces_.emplace_back(
p.second, (p.second == kUNK ? ModelProto::SentencePiece::UNKNOWN
: ModelProto::SentencePiece::CONTROL));
}
auto insert_id = [&has_unk, this](int id, const std::string &w) -> bool {
if (id < 0) return true;
if (id >= trainer_spec_.vocab_size() ||
meta_pieces_.find(id) != meta_pieces_.end() || (has_unk && w == kUNK))
return false;
if (w == kUNK) has_unk = true;
meta_pieces_[id] =
std::make_pair(w, w == kUNK ? ModelProto::SentencePiece::UNKNOWN
: ModelProto::SentencePiece::CONTROL);
return true;
};
CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), kUNK));
CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), kBOS));
CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), kEOS));
CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), kPAD));
CHECK_OR_RETURN(has_unk) << kUNK << " must be defined.";
std::set<std::string> dup;
int id = 0;
auto insert_meta_symbol = [&id, &dup, this](
const std::string &w,
ModelProto::SentencePiece::Type type) -> bool {
if (!dup.insert(w).second) {
LOG(ERROR) << w << " is already defined.";
return false;
}
if (w == kUNK) {
LOG(ERROR) << "<unk> must not be defined with --control_symbols and "
"--user_defined_symbols.";
return false;
}
if (w == kBOS && trainer_spec_.bos_id() >= 0) {
meta_pieces_[trainer_spec_.bos_id()].second = type;
} else if (w == kEOS && trainer_spec_.eos_id() >= 0) {
meta_pieces_[trainer_spec_.eos_id()].second = type;
} else if (w == kPAD && trainer_spec_.pad_id() >= 0) {
meta_pieces_[trainer_spec_.pad_id()].second = type;
} else {
while (meta_pieces_.find(id) != meta_pieces_.end()) ++id;
meta_pieces_[id] = std::make_pair(w, type);
}
return true;
};
for (const auto &w : trainer_spec_.control_symbols()) {
meta_pieces_.emplace_back(w, ModelProto::SentencePiece::CONTROL);
CHECK_OR_RETURN(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL));
}
for (const auto &w : trainer_spec_.user_defined_symbols()) {
meta_pieces_.emplace_back(w, ModelProto::SentencePiece::USER_DEFINED);
CHECK_OR_RETURN(
insert_meta_symbol(w, ModelProto::SentencePiece::USER_DEFINED));
}
return util::OkStatus();

View File

@ -111,8 +111,8 @@ class TrainerInterface {
NormalizerSpec normalizer_spec_;
// Reserved control pieces. e.g., <unk>, <s>, </s>.
// The index corresponds to vocab id.
std::vector<std::pair<std::string, ModelProto::SentencePiece::Type>>
// key is vocab id.
std::map<int, std::pair<std::string, ModelProto::SentencePiece::Type>>
meta_pieces_;
// Detect errors on initialization.

View File

@ -82,10 +82,12 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
}
TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
TrainerSpec trainer_spec;
TrainerSpec base_trainer_spec;
NormalizerSpec normalizer_spec;
trainer_spec.set_model_prefix("model");
trainer_spec.add_input("input");
base_trainer_spec.set_model_prefix("model");
base_trainer_spec.add_input("input");
auto trainer_spec = base_trainer_spec;
// Check default values.
EXPECT_EQ(0, trainer_spec.unk_id());
@ -94,6 +96,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
EXPECT_EQ(-1, trainer_spec.pad_id());
{
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(1);
trainer_spec.set_eos_id(2);
@ -108,6 +111,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
}
{
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(3);
trainer_spec.set_eos_id(2);
@ -122,6 +126,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
}
{
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(-1);
trainer_spec.set_eos_id(1);
@ -134,6 +139,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
}
{
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(-1);
trainer_spec.set_eos_id(-1);
@ -145,6 +151,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
}
{
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(1);
trainer_spec.set_eos_id(2);
@ -167,21 +174,96 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
}
{
// IDs are not contiguous.
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(-1);
trainer_spec.set_eos_id(2);
TrainerInterface trainer(trainer_spec, normalizer_spec);
EXPECT_NOT_OK(trainer.status());
EXPECT_OK(trainer.status());
}
{
// UNK is not defined.
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(-1);
trainer_spec.set_bos_id(0);
trainer_spec.set_eos_id(1);
TrainerInterface trainer(trainer_spec, normalizer_spec);
EXPECT_NOT_OK(trainer.status());
EXPECT_FALSE(trainer.status().ok());
}
{
// UNK is out-of-range.
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(640000);
trainer_spec.set_bos_id(0);
trainer_spec.set_eos_id(1);
TrainerInterface trainer(trainer_spec, normalizer_spec);
EXPECT_FALSE(trainer.status().ok());
}
{
auto trainer_spec = base_trainer_spec;
trainer_spec.set_vocab_size(32000);
trainer_spec.set_unk_id(32000 - 1);
trainer_spec.set_bos_id(32000 - 100);
trainer_spec.set_eos_id(32000 - 200);
TrainerInterface trainer(trainer_spec, normalizer_spec);
EXPECT_OK(trainer.status());
}
{
// Cannot assign <unk> as control symbol.
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(1);
trainer_spec.set_eos_id(2);
trainer_spec.add_control_symbols("<unk>");
TrainerInterface trainer(trainer_spec, normalizer_spec);
EXPECT_FALSE(trainer.status().ok());
}
{
// Dup.
auto trainer_spec = base_trainer_spec;
trainer_spec.add_control_symbols("<foo>");
trainer_spec.add_control_symbols("<foo>");
TrainerInterface trainer(trainer_spec, normalizer_spec);
EXPECT_FALSE(trainer.status().ok());
}
{
auto trainer_spec = base_trainer_spec;
trainer_spec.set_unk_id(0);
trainer_spec.set_bos_id(10);
trainer_spec.set_eos_id(20);
trainer_spec.set_pad_id(30);
// <s>, <pad> are treated as USER_DEFINED,
// </s> is CONTROL.
trainer_spec.add_user_defined_symbols("<s>");
trainer_spec.add_user_defined_symbols("<pad>");
trainer_spec.add_user_defined_symbols("foo");
TrainerInterface trainer(trainer_spec, normalizer_spec);
EXPECT_OK(trainer.status());
EXPECT_EQ(5, trainer.meta_pieces_.size());
EXPECT_EQ("<unk>", trainer.meta_pieces_[0].first);
EXPECT_EQ("<s>", trainer.meta_pieces_[10].first);
EXPECT_EQ("</s>", trainer.meta_pieces_[20].first);
EXPECT_EQ("<pad>", trainer.meta_pieces_[30].first);
EXPECT_EQ("foo", trainer.meta_pieces_[1].first);
EXPECT_EQ(ModelProto::SentencePiece::UNKNOWN,
trainer.meta_pieces_[0].second);
EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED,
trainer.meta_pieces_[10].second);
EXPECT_EQ(ModelProto::SentencePiece::CONTROL,
trainer.meta_pieces_[20].second);
EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED,
trainer.meta_pieces_[30].second);
EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED,
trainer.meta_pieces_[1].second);
}
}

View File

@ -13,8 +13,6 @@
// limitations under the License.!
#include "unigram_model_trainer.h"
#include "builder.h"
#include "normalizer.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
@ -39,7 +37,6 @@ TEST(UnigramTrainerTest, EndToEndTest) {
NormalizerSpec normalizer_spec;
normalizer_spec.set_name("identity");
EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
constexpr int kVocabSize = 8000;
trainer_spec.set_vocab_size(kVocabSize);

View File

@ -14,7 +14,6 @@
#include "word_model_trainer.h"
#include "builder.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "util.h"
@ -46,7 +45,6 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
NormalizerSpec normalizer_spec;
normalizer_spec.set_name("identity");
EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
normalizer_spec.set_add_dummy_prefix(true);
Trainer trainer(trainer_spec, normalizer_spec);