mirror of https://github.com/google/sentencepiece.git
synced 2025-01-06 09:19:12 +03:00

Updated normalizer

This commit is contained in:
parent 4f7af0dfad
commit 4e3bcf1373

Judging from the hunks below, the commit moves normalizer-spec population out of normalizer::Builder into SentencePieceTrainer (adding GetNormalizerSpec and PopulateNormalizerSpec there), converts the Builder chars-map helpers to util::Status-returning functions with output parameters (LoadCharsMap, SaveCharsMap, BuildNFKCMap, RemoveRedundantMap), adds Builder::DecompileCharsMap as the inverse of CompileCharsMap, drops BuildIdentityMap in favor of an empty "identity" blob, lets the Normalizer fall back to identity normalization when precompiled_charsmap is empty, enforces exactly one <unk> piece at model load time, and switches the trainer's reserved meta pieces to a vocab-id-keyed map so reserved ids no longer need to be contiguous. The Python wrapper version is bumped from 0.0.9 to 0.1.0.
(deleted one-line TSV rule file)

@@ -1 +0,0 @@
-20	20	# =>
python/setup.py (inferred)

@@ -25,7 +25,7 @@ setup(name = 'sentencepiece',
       author_email='taku@google.com',
       description = 'SentencePiece python wrapper',
       long_description = long_description,
-      version='0.0.9',
+      version='0.1.0',
       url = 'https://github.com/google/sentencepiece',
       license = 'Apache',
       platforms = 'Unix',
src/bpe_model_trainer_test.cc (inferred)

@@ -14,7 +14,6 @@
 #include "bpe_model_trainer.h"
 
-#include "builder.h"
 #include "sentencepiece_processor.h"
 #include "testharness.h"
 #include "util.h"

@@ -47,7 +46,6 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
   NormalizerSpec normalizer_spec;
   normalizer_spec.set_name("identity");
   normalizer_spec.set_add_dummy_prefix(false);
-  EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
 
   Trainer trainer(trainer_spec, normalizer_spec);
   trainer.Train();

@@ -82,7 +80,6 @@ TEST(BPETrainerTest, EndToEndTest) {
 
   NormalizerSpec normalizer_spec;
   normalizer_spec.set_name("nfkc");
-  EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
 
   constexpr int kVocabSize = 8000;
   trainer_spec.set_vocab_size(kVocabSize);
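These hunks drop the explicit EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(...)) calls and the builder.h include from the BPE trainer test. They appear to be unnecessary because, per the sentencepiece_trainer.cc hunk further down, SentencePieceTrainer::Train now populates the copied spec itself:

  auto copied_normalizer_spec = normalizer_spec;
  RETURN_IF_ERROR(PopulateNormalizerSpec(&copied_normalizer_spec));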
src/builder.cc (192 changed lines)
@@ -13,6 +13,7 @@
 // limitations under the License.!
 
 #include "builder.h"
+#include <functional>
 
 #ifdef ENABLE_NFKC_COMPILE
 #include <unicode/errorcode.h>

@@ -33,10 +34,13 @@
 namespace sentencepiece {
 namespace normalizer {
 namespace {
 
 constexpr int kMaxUnicode = 0x10FFFF;
 
 static constexpr char kDefaultNormalizerName[] = "nfkc";
 
 #ifdef ENABLE_NFKC_COMPILE
-// Normalize |input| with ICU's normalizer with |mode|.
+// Normalize `input` with ICU's normalizer with `mode`.
 Builder::Chars UnicodeNormalize(UNormalizationMode mode,
                                 const Builder::Chars &input) {
   const std::string utf8 = string_util::UnicodeTextToUTF8(input);

@@ -78,7 +82,7 @@ Builder::Chars ToNFD(const Builder::Chars &input) {
 }
 
 // Given an NFKD-normalized string, returns a set of all strings which are
-// normalized into the same |nfkd|. |norm2orig| is the normalized to
+// normalized into the same `nfkd`. `norm2orig` is the normalized to
 // un-normalized character mapping.
 std::vector<Builder::Chars> ExpandUnnormalized(
     const Builder::Chars &nfkd,

@@ -104,8 +108,8 @@ std::vector<Builder::Chars> ExpandUnnormalized(
 }
 #endif
 
-// Normalizes |src| with |chars_map| and returns normalized Chars.
-// |max_len| specifies the maximum length of the key in |chars_map|.
+// Normalizes `src` with `chars_map` and returns normalized Chars.
+// `max_len` specifies the maximum length of the key in `chars_map`.
 Builder::Chars Normalize(const Builder::CharsMap &chars_map,
                          const Builder::Chars &src, int max_len) {
   CHECK_GE(max_len, 1);

@@ -146,7 +150,7 @@ util::Status Builder::CompileCharsMap(const CharsMap &chars_map,
   CHECK_OR_RETURN(output);
   CHECK_OR_RETURN(!chars_map.empty());
 
-  LOG(INFO) << "Loading CharsMap of size " << chars_map.size();
+  LOG(INFO) << "Loading CharsMap of size=" << chars_map.size();
 
   // Aggregates the same target strings to save footprint.
   std::map<Chars, int> normalized2pos;
@@ -201,7 +205,59 @@ util::Status Builder::CompileCharsMap(const CharsMap &chars_map,
                  trie.size() * trie.unit_size());
   *output = Normalizer::EncodePrecompiledCharsMap(trie_blob, normalized);
 
-  LOG(INFO) << "Generated normalizer blob. size= " << output->size();
+  LOG(INFO) << "Generated normalizer blob. size=" << output->size();
 
   return util::OkStatus();
 }
 
+// static
+util::Status Builder::DecompileCharsMap(StringPiece blob,
+                                        Builder::CharsMap *chars_map) {
+  CHECK_OR_RETURN(chars_map);
+  chars_map->clear();
+
+  StringPiece trie_blob, normalized;
+  RETURN_IF_ERROR(
+      Normalizer::DecodePrecompiledCharsMap(blob, &trie_blob, &normalized));
+
+  Darts::DoubleArray trie;
+  trie.set_array(const_cast<char *>(trie_blob.data()),
+                 trie_blob.size() / trie.unit_size());
+
+  std::string key;
+  std::function<void(size_t, size_t)> traverse;
+
+  // Given a Trie node at `node_pos` and the key position at `key_position`,
+  // Expands children nodes from `node_pos`.
+  // When leaf nodes are found, stores them into `chars_map`.
+  traverse = [&traverse, &key, &trie, &normalized, &chars_map](
+                 size_t node_pos, size_t key_pos) -> void {
+    for (int c = 0; c <= 255; ++c) {
+      key.push_back(static_cast<char>(c));
+      size_t copied_node_pos = node_pos;
+      size_t copied_key_pos = key_pos;
+      // Note: `copied_(node|key)_pos` are non-const references.
+      // They store the new positions after node traversal.
+      const Darts::DoubleArray::result_type result = trie.traverse(
+          key.data(), copied_node_pos, copied_key_pos, key.size());
+      if (result >= -1) {   // node exists.
+        if (result >= 0) {  // has a value after transition.
+          const StringPiece value = normalized.data() + result;
+          Chars key_chars, value_chars;
+          for (const auto c : string_util::UTF8ToUnicodeText(key))
+            key_chars.push_back(c);
+          for (const auto c : string_util::UTF8ToUnicodeText(value))
+            value_chars.push_back(c);
+          (*chars_map)[key_chars] = value_chars;
+        }
+        // Recursively traverse.
+        traverse(copied_node_pos, copied_key_pos);
+      }
+      key.pop_back();
+    }
+  };
+
+  traverse(0, 0);
+
+  return util::OkStatus();
+}
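DecompileCharsMap is the new inverse of CompileCharsMap: it re-walks every byte-labelled edge of the double-array trie and rebuilds the key-to-value mapping. A minimal round-trip sketch (a hypothetical standalone main; only the Builder calls shown in the hunks above and below are assumed):

#include <string>
#include "builder.h"
#include "common.h"

using sentencepiece::normalizer::Builder;

int main() {
  Builder::CharsMap chars_map;
  chars_map[{0x0041}] = {0x0061};          // U+0041 'A' -> U+0061 'a'
  chars_map[{0x0041, 0x0042}] = {0x0062};  // "AB" -> 'b'

  // Compile into the precompiled blob (trie + normalized strings)...
  std::string blob;
  CHECK_OK(Builder::CompileCharsMap(chars_map, &blob));

  // ...and decompile it back; the round trip is lossless,
  // exactly as the updated CompileCharsMap test asserts below.
  Builder::CharsMap decompiled;
  CHECK_OK(Builder::DecompileCharsMap(blob, &decompiled));
  CHECK_EQ(chars_map, decompiled);
  return 0;
}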
@@ -209,6 +265,13 @@ util::Status Builder::CompileCharsMap(const CharsMap &chars_map,
 // static
 util::Status Builder::GetPrecompiledCharsMap(const std::string &name,
                                              std::string *output) {
   CHECK_OR_RETURN(output);
 
+  if (name == "identity") {
+    output->clear();
+    return util::OkStatus();
+  }
+
   std::string result;
   for (size_t i = 0; i < kNormalizationRules_size; ++i) {
     const auto *blob = &kNormalizationRules_blob[i];
@@ -222,33 +285,7 @@ util::Status Builder::GetPrecompiledCharsMap(const std::string &name,
 }
 
-// static
-util::Status Builder::PopulateNormalizerSpec(NormalizerSpec *normalizer_spec) {
-  CHECK_OR_RETURN(normalizer_spec);
-
-  if (!normalizer_spec->normalization_rule_tsv().empty()) {
-    CHECK_OR_RETURN(normalizer_spec->precompiled_charsmap().empty())
-        << "precompiled_charsmap is already defined.";
-    const auto chars_map = normalizer::Builder::BuildMapFromFile(
-        normalizer_spec->normalization_rule_tsv());
-    RETURN_IF_ERROR(CompileCharsMap(
-        chars_map, normalizer_spec->mutable_precompiled_charsmap()));
-    normalizer_spec->set_name("user_defined");
-  } else {
-    if (normalizer_spec->name().empty()) {
-      normalizer_spec->set_name(kDefaultNormalizerName);
-    }
-    if (normalizer_spec->precompiled_charsmap().empty()) {
-      RETURN_IF_ERROR(GetPrecompiledCharsMap(
-          normalizer_spec->name(),
-          normalizer_spec->mutable_precompiled_charsmap()));
-    }
-  }
-
-  return util::OkStatus();
-}
-
 // static
-Builder::CharsMap Builder::BuildNFKCMap() {
+util::Status Builder::BuildNFKCMap(CharsMap *chars_map) {
 #ifdef ENABLE_NFKC_COMPILE
   LOG(INFO) << "Running BuildNFKCMap";
 

@@ -286,7 +323,7 @@ Builder::CharsMap Builder::BuildNFKCMap() {
     if (nfkc == nfkd) {
       continue;
     }
-    // Expand all possible sequences which are normalized into the same |nfkd|.
+    // Expand all possible sequences which are normalized into the same `nfkd`.
     for (const auto &nfkd_orig : ExpandUnnormalized(nfkd, norm2orig)) {
       if (nfkd_orig != nfkc) {
         nfkc_map[nfkd_orig] = nfkc;
@@ -294,61 +331,94 @@ Builder::CharsMap Builder::BuildNFKCMap() {
     }
   }
 
-  return RemoveRedundantMap(nfkc_map);
+  RETURN_IF_ERROR(RemoveRedundantMap(&nfkc_map));
+  *chars_map = std::move(nfkc_map);
 #else
-  LOG(FATAL) << "NFKC compile is not enabled."
+  LOG(ERROR) << "NFKC compile is not enabled."
              << " rebuild with ./configure --enable-nfkc-compile";
-  return {};
 #endif
 
+  return util::OkStatus();
 }
 
-// static
-Builder::CharsMap Builder::BuildIdentityMap() {
-  // Adds one dummy entry since empty rule is not allowed.
-  const CharsMap result = {{{0x0020}, {0x0020}}};
-  return result;
-}
-
 // static
-Builder::CharsMap Builder::BuildMapFromFile(StringPiece filename) {
+util::Status Builder::LoadCharsMap(StringPiece filename, CharsMap *chars_map) {
   LOG(INFO) << "Loading maping file: " << filename.data();
+  CHECK_OR_RETURN(chars_map);
 
   io::InputBuffer input(filename);
+  RETURN_IF_ERROR(input.status());
 
   std::string line;
-  CharsMap chars_map;
+  chars_map->clear();
   while (input.ReadLine(&line)) {
     const auto fields = string_util::SplitPiece(line, "\t");
     CHECK_GE(fields.size(), 2);
     std::vector<char32> src, trg;
-    for (const auto &s : string_util::SplitPiece(fields[0], " ")) {
+    for (auto &s : string_util::SplitPiece(fields[0], " ")) {
       s.Consume("U+");
       src.push_back(string_util::HexToInt<char32>(s));
     }
-    for (const auto &s : string_util::SplitPiece(fields[1], " ")) {
+    for (auto &s : string_util::SplitPiece(fields[1], " ")) {
       s.Consume("U+");
       trg.push_back(string_util::HexToInt<char32>(s));
     }
     CHECK(!src.empty());
     CHECK(!trg.empty());
-    chars_map[src] = trg;
+    (*chars_map)[src] = trg;
   }
-  return chars_map;
+
+  return util::OkStatus();
 }
 
 // static
-Builder::CharsMap Builder::RemoveRedundantMap(const CharsMap &chars_map) {
-  CharsMap new_chars_map;
+util::Status Builder::SaveCharsMap(StringPiece filename,
+                                   const Builder::CharsMap &chars_map) {
+  io::OutputBuffer output(filename);
+  RETURN_IF_ERROR(output.status());
+
+  for (const auto &c : chars_map) {
+    std::vector<std::string> src, trg;
+    string_util::UnicodeText srcu, trgu;
+    for (char32 v : c.first) {
+      src.push_back(string_util::IntToHex(v));
+      srcu.push_back(v);
+    }
+    for (char32 v : c.second) {
+      trg.push_back(string_util::IntToHex(v));
+      trgu.push_back(v);
+    }
+    std::string line = string_util::Join(src, " ") + "\t" +
+                       string_util::Join(trg, " ") + "\t# " +
+                       string_util::UnicodeTextToUTF8(c.first) + " => " +
+                       string_util::UnicodeTextToUTF8(c.second);
+    line = string_util::StringReplace(line, "\n", " ", true);
+    line = string_util::StringReplace(line, "\r", " ", true);
+    output.WriteLine(line);
+  }
+
+  return util::OkStatus();
+}
+
+// static
+util::Status Builder::RemoveRedundantMap(CharsMap *chars_map) {
+  CHECK_OR_RETURN(chars_map);
+
+  CharsMap new_chars_map;
   size_t max_len = 0;
-  for (const auto &p : chars_map) {
+  for (const auto &p : *chars_map) {
     max_len = std::max(p.first.size(), max_len);
     if (p.first.size() == 1) {
       new_chars_map.insert(p);
     }
   }
-  CHECK_GT(max_len, 0);
+  CHECK_GT_OR_RETURN(max_len, 0);
 
-  // Checks whether the rules with size of |len| can be normalized by
+  // Checks whether the rules with size of `len` can be normalized by
   // the rules with size of [1 .. len - 1].
   for (size_t len = 2; len <= max_len; ++len) {
-    for (const auto &p : chars_map) {
+    for (const auto &p : *chars_map) {
       if (p.first.size() == len &&
           p.second != Normalize(new_chars_map, p.first, len - 1)) {
         new_chars_map.insert(p);
@@ -356,12 +426,14 @@ Builder::CharsMap Builder::RemoveRedundantMap(const CharsMap &chars_map) {
       }
     }
   }
 
-  // Verify all characters in |chars_map| are normalized by |new_chars_map|.
-  for (const auto &p : chars_map) {
-    CHECK_EQ(p.second, Normalize(new_chars_map, p.first, max_len));
+  // Verify all characters in `chars_map` are normalized by `new_chars_map`.
+  for (const auto &p : *chars_map) {
+    CHECK_EQ_OR_RETURN(p.second, Normalize(new_chars_map, p.first, max_len));
   }
 
-  return new_chars_map;
+  *chars_map = std::move(new_chars_map);
+
+  return util::OkStatus();
 }
 }  // namespace normalizer
 }  // namespace sentencepiece
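For reference, the SaveCharsMap added above emits one rule per line: the source code points as space-separated hex, a tab, the target code points, a tab, then a "# src => trg" comment (newlines in the pieces are replaced by spaces). Assuming string_util::IntToHex prints bare hex digits, the NFKC rule for ㈱ → (株), exercised in the tests below, would come out roughly as:

3231<tab>28 682A 29<tab># ㈱ => (株)

The deleted one-line TSV at the top of this commit ("20 20 # => ") has the same shape: the identity rule mapping U+0020 to itself.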
src/builder.h (inferred)

@@ -45,14 +45,13 @@ class Builder {
   static util::Status CompileCharsMap(const CharsMap &chars_map,
                                       std::string *output);
 
+  // Decompiles `blob` into `chars_map`.
+  static util::Status DecompileCharsMap(StringPiece blob, CharsMap *chars_map);
+
   // Returns a pre-compiled binary index with `name`.
   static util::Status GetPrecompiledCharsMap(const std::string &name,
                                              std::string *output);
 
-  // Populates necessary fields (precompiled_charmap) from
-  // `name` or `normalization_rule_tsv` fields in `normalizer_spec`.
-  static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec);
-
   // Makes a normalization mapping based on NFKC.
   //
   // Note that Normalizer/Builder classes do not support

@@ -88,16 +87,17 @@ class Builder {
   // normalizer is the goal of SentencePiece.
   //
   // TODO(taku): Make NFC, NFD, and NFKD mapping if necessary.
-  static CharsMap BuildNFKCMap();
-
-  // Returns identity mapping, which dose not perform any normalization.
-  static CharsMap BuildIdentityMap();
+  static util::Status BuildNFKCMap(CharsMap *chars_map);
 
   // Builds Chars map save in `filename`.
   // Format:
   //   src_uchar1 src_uchar2 ... <tab> trg_uchar1 trg_uchar2...
-  //   (src|trg)_ucharX must be a hex of UCS4.
-  static CharsMap BuildMapFromFile(StringPiece filename);
+  //   (src|trg)_ucharX must be a hex of Unicode code point.
+  static util::Status LoadCharsMap(StringPiece filename, CharsMap *chars_map);
+
+  // Saves Chars map to `filename` as TSV.
+  static util::Status SaveCharsMap(StringPiece filename,
+                                   const CharsMap &chars_map);
 
  private:
   FRIEND_TEST(BuilderTest, RemoveRedundantMapTest);

@@ -105,7 +105,7 @@ class Builder {
   // Removes redundant rules from `chars_map`.
   // When char_maps have "aa" => "bb" and "a" => "b", the first
   // rule is not necessary since the second rule can cover the first rule.
-  static CharsMap RemoveRedundantMap(const CharsMap &chars_map);
+  static util::Status RemoveRedundantMap(CharsMap *chars_map);
 };
 }  // namespace normalizer
 }  // namespace sentencepiece
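The header now makes the calling convention uniform: map builders take an output parameter and report failure through util::Status instead of returning a CharsMap and LOG(FATAL)-ing on error. A hedged sketch of the old versus new call sites ("rules.tsv" is a made-up filename; RETURN_IF_ERROR is the status macro already used throughout these hunks):

#include "builder.h"
#include "common.h"

using sentencepiece::normalizer::Builder;

// Before: const Builder::CharsMap cmap = Builder::BuildMapFromFile("rules.tsv");
// After:
sentencepiece::util::Status LoadAndResave() {
  Builder::CharsMap cmap;
  RETURN_IF_ERROR(Builder::LoadCharsMap("rules.tsv", &cmap));       // was BuildMapFromFile()
  RETURN_IF_ERROR(Builder::SaveCharsMap("rules_copy.tsv", cmap));   // new API
  return sentencepiece::util::OkStatus();
}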
src/builder_test.cc (inferred)

@@ -15,6 +15,7 @@
 #include "builder.h"
 #include "common.h"
 #include "normalizer.h"
+#include "sentencepiece_trainer.h"
 #include "testharness.h"
 #include "util.h"

@@ -33,12 +34,12 @@ TEST(BuilderTest, RemoveRedundantMapTest) {
   chars_map[{0x0061, 0x0062}] = {0x0041, 0x0042};
   chars_map[{0x0061, 0x0062, 0x0063}] = {0x0043, 0x0042, 0x0041};
 
-  const auto new_chars_map = Builder::RemoveRedundantMap(chars_map);
-  EXPECT_EQ(3, new_chars_map.size());
-  EXPECT_EQ(new_chars_map.end(), new_chars_map.find({0x0061, 0x0062}));
-  EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0061}));
-  EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0062}));
-  EXPECT_NE(new_chars_map.end(), new_chars_map.find({0x0061, 0x0062, 0x0063}));
+  EXPECT_OK(Builder::RemoveRedundantMap(&chars_map));
+  EXPECT_EQ(3, chars_map.size());
+  EXPECT_EQ(chars_map.end(), chars_map.find({0x0061, 0x0062}));
+  EXPECT_NE(chars_map.end(), chars_map.find({0x0061}));
+  EXPECT_NE(chars_map.end(), chars_map.find({0x0062}));
+  EXPECT_NE(chars_map.end(), chars_map.find({0x0061, 0x0062, 0x0063}));
 }
 
 TEST(BuilderTest, GetPrecompiledCharsMapWithInvalidNameTest) {

@@ -47,25 +48,19 @@ TEST(BuilderTest, GetPrecompiledCharsMapWithInvalidNameTest) {
   EXPECT_NOT_OK(Builder::GetPrecompiledCharsMap("__UNKNOWN__", &output));
 }
 
-TEST(BuilderTest, BuildIdentityMapTest) {
-  const auto m = Builder::BuildIdentityMap();
-  EXPECT_EQ(1, m.size());
-}
-
 TEST(BuilderTest, BuildNFKCMapTest) {
+  Builder::CharsMap chars_map;
 #ifdef ENABLE_NFKC_COMPILE
-  const auto m = Builder::BuildNFKCMap();
-  EXPECT_TRUE(!m.empty());
+  EXPECT_OK(Builder::BuildNFKCMap(&chars_map));
+  EXPECT_TRUE(!chars_map.empty());
 #else
-  EXPECT_DEATH(Builder::BuildNFKCMap());
+  // EXPECT_DEATH(Builder::BuildNFKCMap(&chars_map));
 #endif
 }
 
 TEST(BuilderTest, GetPrecompiledCharsMapTest) {
   {
-    NormalizerSpec spec;
-    spec.set_name("nfkc");
-    EXPECT_OK(Builder::PopulateNormalizerSpec(&spec));
+    const NormalizerSpec spec = SentencePieceTrainer::GetNormalizerSpec("nfkc");
     const Normalizer normalizer(spec);
     EXPECT_EQ(WS "ABC", normalizer.Normalize("ABC"));
     EXPECT_EQ(WS "(株)", normalizer.Normalize("㈱"));

@@ -73,9 +68,9 @@ TEST(BuilderTest, GetPrecompiledCharsMapTest) {
   }
 
   {
-    NormalizerSpec spec;
-    spec.set_name("identity");
-    EXPECT_OK(Builder::PopulateNormalizerSpec(&spec));
+    const NormalizerSpec spec =
+        SentencePieceTrainer::GetNormalizerSpec("identity");
+    EXPECT_TRUE(spec.precompiled_charsmap().empty());
     const Normalizer normalizer(spec);
     EXPECT_EQ(WS "ABC", normalizer.Normalize("ABC"));
     EXPECT_EQ(WS "㈱", normalizer.Normalize("㈱"));

@@ -99,6 +94,11 @@ TEST(BuilderTest, CompileCharsMap) {
   NormalizerSpec spec;
   EXPECT_OK(
       Builder::CompileCharsMap(chars_map, spec.mutable_precompiled_charsmap()));
+  Builder::CharsMap decompiled_chars_map;
+  EXPECT_OK(Builder::DecompileCharsMap(spec.precompiled_charsmap(),
+                                       &decompiled_chars_map));
+  EXPECT_EQ(chars_map, decompiled_chars_map);
+
   spec.set_add_dummy_prefix(false);
   const Normalizer normalizer(spec);

@@ -112,12 +112,30 @@ TEST(BuilderTest, CompileCharsMap) {
   EXPECT_EQ("ABCabcD", normalizer.Normalize("abcあいうd"));
 }
 
-TEST(BuilderTest, BuildMapFromFileTest) {
-  const auto cmap = Builder::BuildMapFromFile("../data/nfkc.tsv");
-  std::string expected, precompiled;
-  EXPECT_OK(Builder::CompileCharsMap(cmap, &precompiled));
-  EXPECT_OK(Builder::GetPrecompiledCharsMap("nfkc", &expected));
-  EXPECT_EQ(expected, precompiled);
+TEST(BuilderTest, LoadCharsMapTest) {
+  Builder::CharsMap chars_map;
+  EXPECT_OK(Builder::LoadCharsMap("../data/nfkc.tsv", &chars_map));
+
+  std::string precompiled, expected;
+  EXPECT_OK(Builder::CompileCharsMap(chars_map, &precompiled));
+
+  // Round-trip.
+  Builder::CharsMap decompiled_chars_map;
+  EXPECT_OK(Builder::DecompileCharsMap(precompiled, &decompiled_chars_map));
+  EXPECT_EQ(chars_map, decompiled_chars_map);
+
+  test::ScopedTempFile output_tsv("output.tsv");
+  EXPECT_OK(Builder::SaveCharsMap(output_tsv.filename(), chars_map));
+
+  Builder::CharsMap saved_chars_map;
+  EXPECT_OK(Builder::LoadCharsMap(output_tsv.filename(), &saved_chars_map));
+  EXPECT_EQ(chars_map, saved_chars_map);
+
+#ifdef ENABLE_NFKC_COMPILE
+  Builder::CharsMap nfkc_map;
+  EXPECT_OK(Builder::BuildNFKCMap(&nfkc_map));
+  EXPECT_OK(Builder::CompileCharsMap(nfkc_map, &expected));
+#endif
 }
 
 TEST(BuilderTest, ContainsTooManySharedPrefixTest) {

@@ -129,7 +147,7 @@ TEST(BuilderTest, ContainsTooManySharedPrefixTest) {
     chars_map[keys] = {'b'};
   }
   std::string output;
-  EXPECT_NOT_OK(Builder::CompileCharsMap(chars_map, &output));
+  EXPECT_FALSE(Builder::CompileCharsMap(chars_map, &output).ok());
 }
 
 }  // namespace normalizer
|
@ -13,7 +13,6 @@
|
||||
// limitations under the License.!
|
||||
|
||||
#include "char_model_trainer.h"
|
||||
#include "builder.h"
|
||||
#include "sentencepiece_processor.h"
|
||||
#include "testharness.h"
|
||||
#include "util.h"
|
||||
@ -45,7 +44,6 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
|
||||
|
||||
NormalizerSpec normalizer_spec;
|
||||
normalizer_spec.set_name("identity");
|
||||
EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
|
||||
|
||||
Trainer trainer(trainer_spec, normalizer_spec);
|
||||
trainer.Train();
|
||||
|
src/compile_charsmap_main.cc (inferred)

@@ -18,34 +18,17 @@
 #include <string>
 #include "builder.h"
 #include "flags.h"
+#include "sentencepiece_processor.h"
 #include "stringpiece.h"
 #include "util.h"
 
 using sentencepiece::normalizer::Builder;
+using sentencepiece::util::Status;
 
 DEFINE_bool(output_precompiled_header, false, "make normalization_rule.h file");
 
 namespace sentencepiece {
 namespace {
-void WriteTSV(const Builder::CharsMap &cmap, StringPiece filename) {
-  sentencepiece::io::OutputBuffer output(filename);
-  for (const auto &c : cmap) {
-    std::vector<std::string> src, trg;
-    for (char32 v : c.first) {
-      src.push_back(string_util::IntToHex(v));
-    }
-    for (char32 v : c.second) {
-      trg.push_back(string_util::IntToHex(v));
-    }
-    std::string line = string_util::Join(src, " ") + "\t" +
-                       string_util::Join(trg, " ") + "\t# " +
-                       string_util::UnicodeTextToUTF8(c.first) + " => " +
-                       string_util::UnicodeTextToUTF8(c.second);
-    line = string_util::StringReplace(line, "\n", " ", true);
-    line = string_util::StringReplace(line, "\r", " ", true);
-    output.WriteLine(line);
-  }
-}
 
 std::string ToHexData(StringPiece data) {
   const char *begin = data.data();

@@ -81,9 +64,9 @@ std::string ToHexData(StringPiece data) {
 int main(int argc, char **argv) {
   sentencepiece::flags::ParseCommandLineFlags(argc, argv);
 
-  const std::vector<std::pair<std::string, std::function<Builder::CharsMap()>>>
-      kRuleList = {{"nfkc", Builder::BuildNFKCMap},
-                   {"identity", Builder::BuildIdentityMap}};
+  const std::vector<
+      std::pair<std::string, std::function<Status(Builder::CharsMap *)>>>
+      kRuleList = {{"nfkc", Builder::BuildNFKCMap}};
 
   constexpr char kHeader[] =
       R"(#ifndef NORMALIZATION_RULE_H_

@@ -107,13 +90,18 @@ constexpr BinaryBlob kNormalizationRules_blob[] = {)";
   os << kHeader;
 
   for (const auto &p : kRuleList) {
-    const auto normalized_map = p.second();
+    Builder::CharsMap normalized_map;
+    CHECK_OK(p.second(&normalized_map));
 
+    // Write Header.
     std::string index;
     CHECK_OK(Builder::CompileCharsMap(normalized_map, &index));
     os << "{ \"" << p.first << "\", " << index.size() << ",\n";
     os << sentencepiece::ToHexData(index);
     os << " },";
-    sentencepiece::WriteTSV(normalized_map, p.first + ".tsv");
+
+    // Write TSV file.
+    CHECK_OK(Builder::SaveCharsMap(p.first + ".tsv", normalized_map));
   }
 
   os << "};\n";

@@ -124,6 +112,7 @@ constexpr BinaryBlob kNormalizationRules_blob[] = {)";
   if (FLAGS_output_precompiled_header) {
     constexpr char kPrecompiledHeaderFileName[] = "normalization_rule.h";
     sentencepiece::io::OutputBuffer output(kPrecompiledHeaderFileName);
+    CHECK_OK(output.status());
     output.Write(os.str());
   }
src/model_interface.cc (inferred)

@@ -57,14 +57,14 @@ bool ModelInterface::IsUnknown(int id) const {
 void ModelInterface::InitializePieces(bool enable_user_defined) {
   pieces_.clear();
   reserved_id_map_.clear();
-  unk_id_ = 0;
+  unk_id_ = -1;
 
   for (int i = 0; i < model_proto_->pieces_size(); ++i) {
     const auto &sp = model_proto_->pieces(i);
     if (!enable_user_defined &&
         sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
       status_ = util::StatusBuilder(util::error::INTERNAL)
-                << "user defined symbol is not supported.";
+                << "User defined symbol is not supported.";
       return;
     }
 

@@ -78,8 +78,19 @@ void ModelInterface::InitializePieces(bool enable_user_defined) {
       return;
     }
 
-    if (sp.type() == ModelProto::SentencePiece::UNKNOWN) unk_id_ = i;
+    if (sp.type() == ModelProto::SentencePiece::UNKNOWN) {
+      if (unk_id_ >= 0) {
+        status_ = util::StatusBuilder(util::error::INTERNAL)
+                  << "unk is already defined.";
+        return;
+      }
+      unk_id_ = i;
+    }
   }
+
+  if (unk_id_ == -1)
+    status_ = util::StatusBuilder(util::error::INTERNAL)
+              << "unk is not defined.";
 }
 
 std::vector<StringPiece> SplitIntoWords(StringPiece text) {
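The loader now enforces exactly one piece of type UNKNOWN: unk_id_ starts at -1, a second UNKNOWN piece is an error, and a model containing none is rejected. A toy, self-contained model of the invariant (plain C++, not the real classes):

#include <cassert>
#include <vector>

enum class Type { NORMAL, UNKNOWN, CONTROL, USER_DEFINED };

// Returns the index of the single UNKNOWN piece; asserts on zero or two.
int FindUnk(const std::vector<Type>& pieces) {
  int unk_id = -1;
  for (size_t i = 0; i < pieces.size(); ++i) {
    if (pieces[i] == Type::UNKNOWN) {
      assert(unk_id < 0 && "unk is already defined.");
      unk_id = static_cast<int>(i);
    }
  }
  assert(unk_id != -1 && "unk is not defined.");
  return unk_id;
}

int main() {
  assert(FindUnk({Type::UNKNOWN, Type::CONTROL, Type::NORMAL}) == 0);
  return 0;
}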
src/normalizer.cc (inferred)

@@ -23,36 +23,32 @@ namespace normalizer {
 
 constexpr int Normalizer::kMaxTrieResultsSize;
 
-Normalizer::Normalizer(const NormalizerSpec &spec) : spec_(&spec) {
+Normalizer::Normalizer(const NormalizerSpec &spec)
+    : spec_(&spec), status_(util::OkStatus()) {
   StringPiece index = spec.precompiled_charsmap();
   if (index.empty()) {
-    status_ = util::InvalidArgumentError("precompiled_charsmap is empty.");
-    return;
+    LOG(INFO) << "precompiled_charsmap is empty. use identity normalization.";
+  } else {
+    StringPiece trie_blob, normalized;
+    status_ = DecodePrecompiledCharsMap(index, &trie_blob, &normalized);
+    if (!status_.ok()) return;
+
+    // Reads the body of double array.
+    trie_ = port::MakeUnique<Darts::DoubleArray>();
+
+    // The second arg of set_array is not the size of blob,
+    // but the number of double array units.
+    trie_->set_array(const_cast<char *>(trie_blob.data()),
+                     trie_blob.size() / trie_->unit_size());
+
+    normalized_ = normalized.data();
   }
-
-  StringPiece trie_blob, normalized;
-  status_ = DecodePrecompiledCharsMap(index, &trie_blob, &normalized);
-  if (!status_.ok()) return;
-
-  // Reads the body of double array.
-  trie_ = port::MakeUnique<Darts::DoubleArray>();
-
-  // The second arg of set_array is not the size of blob,
-  // but the number of double array units.
-  trie_->set_array(const_cast<char *>(trie_blob.data()),
-                   trie_blob.size() / trie_->unit_size());
-
-  normalized_ = normalized.data();
 }
 
 Normalizer::~Normalizer() {}
 
 util::Status Normalizer::Normalize(StringPiece input, std::string *normalized,
                                    std::vector<size_t> *norm_to_orig) const {
-  if (trie_ == nullptr || normalized_ == nullptr) {
-    return util::InternalError("Normalizer model is not available.");
-  }
-
   norm_to_orig->clear();
   normalized->clear();
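The constructor change above means an empty precompiled_charsmap is no longer an error: the trie is simply left unset, and NormalizePrefix (next hunks) degrades to a byte-for-byte copy. A small sketch of the new behavior, mirroring the updated StatusTest further down:

#include "common.h"
#include "normalizer.h"

using sentencepiece::NormalizerSpec;
using sentencepiece::normalizer::Normalizer;

int main() {
  NormalizerSpec spec;  // precompiled_charsmap left empty.
  const Normalizer normalizer(spec);
  CHECK_OK(normalizer.status());  // Before this commit: InvalidArgumentError.
  // Character rules are now the identity (e.g. "㈱" is not folded to "(株)");
  // the dummy-prefix and whitespace options in `spec` still apply.
  return 0;
}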
@@ -60,6 +56,8 @@ util::Status Normalizer::Normalize(StringPiece input, std::string *normalized,
     return util::OkStatus();
   }
 
+  RETURN_IF_ERROR(status());
+
   int consumed = 0;
 
   // Ignores heading space.

@@ -144,7 +142,7 @@ util::Status Normalizer::Normalize(StringPiece input, std::string *normalized,
   const StringPiece space = spec_->escape_whitespaces() ? kSpaceSymbol : " ";
   while (string_util::EndsWith(*normalized, space)) {
     const int length = normalized->size() - space.size();
-    if (length < 0) return util::InternalError("length < 0");
+    CHECK_GE_OR_RETURN(length, 0);
     consumed = (*norm_to_orig)[length];
     normalized->resize(length);
     norm_to_orig->resize(length);

@@ -153,9 +151,7 @@ util::Status Normalizer::Normalize(StringPiece input, std::string *normalized,
 
   norm_to_orig->push_back(consumed);
 
-  if (norm_to_orig->size() != normalized->size() + 1) {
-    return util::InternalError("norm_to_org and normalized are inconsistent");
-  }
+  CHECK_EQ_OR_RETURN(norm_to_orig->size(), normalized->size() + 1);
 
   return util::OkStatus();
 }

@@ -173,25 +169,27 @@ std::pair<StringPiece, int> Normalizer::NormalizePrefix(
 
   if (input.empty()) return result;
 
-  // Allocates trie_results in stack, which makes the encoding speed 36% faster.
-  // (38k sentences/sec => 60k sentences/sec).
-  // Builder checks that the result size never exceeds kMaxTrieResultsSize.
-  // This array consumes 0.5kByte in stack, which is less than
-  // default stack frames (16kByte).
-  Darts::DoubleArray::result_pair_type
-      trie_results[Normalizer::kMaxTrieResultsSize];
-
-  const size_t num_nodes =
-      trie_->commonPrefixSearch(input.data(), trie_results,
-                                Normalizer::kMaxTrieResultsSize, input.size());
-
-  // Finds the longest rule.
   size_t longest_length = 0;
   int longest_value = 0;
-  for (size_t k = 0; k < num_nodes; ++k) {
-    if (longest_length == 0 || trie_results[k].length > longest_length) {
-      longest_length = trie_results[k].length;  // length of prefix
-      longest_value = trie_results[k].value;    // pointer to |normalized_|.
+
+  if (trie_ != nullptr) {
+    // Allocates trie_results in stack, which makes the encoding speed 36%
+    // faster. (38k sentences/sec => 60k sentences/sec). Builder checks that the
+    // result size never exceeds kMaxTrieResultsSize. This array consumes
+    // 0.5kByte in stack, which is less than default stack frames (16kByte).
+    Darts::DoubleArray::result_pair_type
+        trie_results[Normalizer::kMaxTrieResultsSize];
+
+    const size_t num_nodes = trie_->commonPrefixSearch(
+        input.data(), trie_results, Normalizer::kMaxTrieResultsSize,
+        input.size());
+
+    // Finds the longest rule.
+    for (size_t k = 0; k < num_nodes; ++k) {
+      if (longest_length == 0 || trie_results[k].length > longest_length) {
+        longest_length = trie_results[k].length;  // length of prefix
+        longest_value = trie_results[k].value;    // pointer to |normalized_|.
+      }
     }
   }

@@ -239,7 +237,7 @@ util::Status Normalizer::DecodePrecompiledCharsMap(StringPiece blob,
       !string_util::DecodePOD<uint32>(
           StringPiece(blob.data(), sizeof(trie_blob_size)), &trie_blob_size) ||
       trie_blob_size >= blob.size()) {
-    return util::InternalError("Trie blob is broken.");
+    return util::InternalError("Blob for normalization rule is broken.");
   }
 
   blob.remove_prefix(sizeof(trie_blob_size));
src/normalizer_test.cc (inferred)

@@ -14,6 +14,7 @@
 
 #include "normalizer.h"
 #include "builder.h"
+#include "sentencepiece_trainer.h"
 #include "testharness.h"
 #include "util.h"

@@ -23,20 +24,14 @@ namespace {
 // Space symbol
 #define WS "\xe2\x96\x81"
 
+// Replacement char
+#define RC "\xEF\xBF\xBD"
+
 NormalizerSpec MakeDefaultSpec() {
-  NormalizerSpec normalizer_spec;
-  normalizer_spec.set_name("nfkc");
-  EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
-  return normalizer_spec;
+  return SentencePieceTrainer::GetNormalizerSpec("nfkc");
 }
 }  // namespace
 
-TEST(NormalizerTest, NormalizeErrorTest) {
-  NormalizerSpec spec;
-  Normalizer normalizer(spec);
-  EXPECT_NOT_OK(normalizer.Normalize("test", nullptr, nullptr));
-}
-
 TEST(NormalizerTest, NormalizeTest) {
   auto spec = MakeDefaultSpec();
   const Normalizer normalizer(spec);

@@ -211,6 +206,44 @@ TEST(NormalizeTest, NomalizeWithSpaceContainedRules) {
   EXPECT_EQ(" A F G ", normalizer.Normalize("ad"));
   EXPECT_EQ(" A F G B", normalizer.Normalize("adb"));
 }
 
+// Added several corner cases around spaces.
+struct SpacePattern {
+  bool add_dummy_prefix;
+  bool remove_extra_whitespaces;
+  bool escape_whitespaces;
+  const char *input;
+  const char *expected;
+};
+
+constexpr SpacePattern kSpacePatternData[] = {
+    {false, false, false, WS, WS},    {false, false, true, WS, WS},
+    {false, true, false, WS, WS},     {false, true, true, WS, ""},
+    {true, false, false, WS, " " WS}, {true, false, true, WS, WS WS},
+    {true, true, false, WS, " " WS},  {true, true, true, WS, ""},
+    {false, false, false, " ", " "},  {false, false, true, " ", WS},
+    {false, true, false, " ", ""},    {false, true, true, " ", ""},
+    {true, false, false, " ", " "},   {true, false, true, " ", WS WS},
+    {true, true, false, " ", ""},     {true, true, true, " ", ""}};
+
+  for (const auto &c : kSpacePatternData) {
+    spec.set_add_dummy_prefix(c.add_dummy_prefix);
+    spec.set_remove_extra_whitespaces(c.remove_extra_whitespaces);
+    spec.set_escape_whitespaces(c.escape_whitespaces);
+    const Normalizer normalizer(spec);
+    EXPECT_EQ(c.expected, normalizer.Normalize(c.input));
+  }
+}
+
+TEST(NormalizerTest, NormalizeReplacementChar) {
+  auto spec = MakeDefaultSpec();
+  spec.set_add_dummy_prefix(false);
+  const Normalizer normalizer(spec);
+  EXPECT_EQ("abc" RC "xy", normalizer.Normalize("abc\x80xy"));
+  EXPECT_EQ("abc" RC, normalizer.Normalize("abc\xc3"));
+  EXPECT_EQ("ab" RC RC "xy", normalizer.Normalize("ab\xe3\x81xy"));
+  EXPECT_EQ("a" RC RC RC "xy", normalizer.Normalize("a\xf3\x81\x81xy"));
+  EXPECT_EQ("ab" RC RC "xy", normalizer.Normalize("ab\xc0\x82xy"));
+}
+
 TEST(NormalizerTest, NormalizeFullTest) {

@@ -310,6 +343,12 @@ TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest) {
 TEST(NormalizerTest, StatusTest) {
   NormalizerSpec spec;
+  {
+    const Normalizer normalizer(spec);
+    EXPECT_OK(normalizer.status());  // fallback to identity.
+  }
+
   {
     spec.set_precompiled_charsmap("x");
     const Normalizer normalizer(spec);
     EXPECT_FALSE(normalizer.status().ok());
   }
src/sentencepiece_processor_test.cc (inferred)

@@ -19,6 +19,7 @@
 #include "normalizer.h"
 #include "sentencepiece.pb.h"
 #include "sentencepiece_model.pb.h"
+#include "sentencepiece_trainer.h"
 #include "stringpiece.h"
 #include "testharness.h"
 #include "util.h"

@@ -100,10 +101,7 @@ std::vector<std::string> GetSpVec(const SentencePieceText &spt) {
 }
 
 NormalizerSpec MakeDefaultNormalizerSpec() {
-  NormalizerSpec normalizer_spec;
-  normalizer_spec.set_name("nfkc");
-  EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
-  return normalizer_spec;
+  return SentencePieceTrainer::GetNormalizerSpec("nfkc");
 }
 
 TEST(SentencepieceProcessorTest, StatusTest) {
src/sentencepiece_trainer.cc (inferred)

@@ -25,6 +25,9 @@
 #include "util.h"
 
 namespace sentencepiece {
+namespace {
+static constexpr char kDefaultNormalizerName[] = "nfkc";
+}  // namespace
 
 // static
 util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {

@@ -36,13 +39,21 @@ util::Status SentencePieceTrainer::Train(const TrainerSpec &trainer_spec) {
 util::Status SentencePieceTrainer::Train(
     const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec) {
   auto copied_normalizer_spec = normalizer_spec;
-  RETURN_IF_ERROR(
-      normalizer::Builder::PopulateNormalizerSpec(&copied_normalizer_spec));
-
+  RETURN_IF_ERROR(PopulateNormalizerSpec(&copied_normalizer_spec));
   auto trainer = TrainerFactory::Create(trainer_spec, copied_normalizer_spec);
   return trainer->Train();
 }
 
+// static
+NormalizerSpec SentencePieceTrainer::GetNormalizerSpec(
+    const std::string &name) {
+  NormalizerSpec spec;
+  spec.set_name(name);
+  CHECK_OK(normalizer::Builder::GetPrecompiledCharsMap(
+      spec.name(), spec.mutable_precompiled_charsmap()));
+  return spec;
+}
+
 // static
 util::Status SentencePieceTrainer::SetProtoField(
     const std::string &field_name, const std::string &value,

@@ -161,8 +172,36 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs(
 util::Status SentencePieceTrainer::Train(const std::string &args) {
   TrainerSpec trainer_spec;
   NormalizerSpec normalizer_spec;
-  CHECK_OK(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));
+  RETURN_IF_ERROR(MergeSpecsFromArgs(args, &trainer_spec, &normalizer_spec));
   return Train(trainer_spec, normalizer_spec);
 }
 
+// static
+util::Status SentencePieceTrainer::PopulateNormalizerSpec(
+    NormalizerSpec *normalizer_spec) {
+  CHECK_OR_RETURN(normalizer_spec);
+
+  if (!normalizer_spec->normalization_rule_tsv().empty()) {
+    CHECK_OR_RETURN(normalizer_spec->precompiled_charsmap().empty())
+        << "precompiled_charsmap is already defined.";
+    normalizer::Builder::CharsMap chars_map;
+    RETURN_IF_ERROR(normalizer::Builder::LoadCharsMap(
+        normalizer_spec->normalization_rule_tsv(), &chars_map));
+    RETURN_IF_ERROR(normalizer::Builder::CompileCharsMap(
+        chars_map, normalizer_spec->mutable_precompiled_charsmap()));
+    normalizer_spec->set_name("user_defined");
+  } else {
+    if (normalizer_spec->name().empty()) {
+      normalizer_spec->set_name(kDefaultNormalizerName);
+    }
+    if (normalizer_spec->precompiled_charsmap().empty()) {
+      RETURN_IF_ERROR(normalizer::Builder::GetPrecompiledCharsMap(
+          normalizer_spec->name(),
+          normalizer_spec->mutable_precompiled_charsmap()));
+    }
+  }
+
+  return util::OkStatus();
+}
+
 }  // namespace sentencepiece
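Callers that used to reach into normalizer::Builder now go through SentencePieceTrainer. A sketch of the two entry points added here; "my_rules.tsv" is a made-up filename, and GetNormalizerSpec is test-only by design (see the header comment below):

#include "common.h"
#include "normalizer.h"
#include "sentencepiece_trainer.h"

using sentencepiece::NormalizerSpec;
using sentencepiece::SentencePieceTrainer;

int main() {
  // Test helper: CHECK-fails internally if the name is unknown.
  const NormalizerSpec nfkc = SentencePieceTrainer::GetNormalizerSpec("nfkc");
  const sentencepiece::normalizer::Normalizer normalizer(nfkc);
  CHECK_OK(normalizer.status());

  // Production path: fills precompiled_charsmap from `name` or
  // `normalization_rule_tsv` and reports problems as util::Status.
  NormalizerSpec spec;
  spec.set_normalization_rule_tsv("my_rules.tsv");  // hypothetical TSV file.
  CHECK_OK(SentencePieceTrainer::PopulateNormalizerSpec(&spec));
  return 0;
}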
src/sentencepiece_trainer.h (inferred)

@@ -45,6 +45,15 @@ class SentencePieceTrainer {
   // '--input=data --model_prefix=m --vocab_size=8192 model_type=unigram'
   static util::Status Train(const std::string &args);
 
+  // Handy function to make a normalizer spec from the pre-compiled
+  // normalization name. Do not use this method in production as it crashes
+  // when `name` is invalid. Useful for unittesting.
+  static NormalizerSpec GetNormalizerSpec(const std::string &name);
+
+  // Populates necessary fields (precompiled_charmap) from
+  // `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`.
+  static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec);
+
   // Overrides `trainer_spec` and `normalizer_spec` with the
   // command-line string in `args`.
   static util::Status MergeSpecsFromArgs(const std::string &args,
src/spm_normalize_main.cc (inferred)

@@ -19,6 +19,7 @@
 #include "sentencepiece.pb.h"
 #include "sentencepiece_model.pb.h"
 #include "sentencepiece_processor.h"
+#include "sentencepiece_trainer.h"
 #include "util.h"
 
 DEFINE_string(model, "", "Model file name");

@@ -32,24 +33,30 @@ DEFINE_string(normalization_rule_tsv, "", "Normalization rule TSV file. ");
 DEFINE_bool(remove_extra_whitespaces, true, "Remove extra whitespaces");
 DEFINE_string(output, "", "Output filename");
 
+using sentencepiece::ModelProto;
+using sentencepiece::NormalizerSpec;
+using sentencepiece::SentencePieceProcessor;
+using sentencepiece::SentencePieceTrainer;
-using sentencepiece::normalizer::Builder;
+using sentencepiece::normalizer::Normalizer;
 
 int main(int argc, char *argv[]) {
   std::vector<std::string> rest_args;
   sentencepiece::flags::ParseCommandLineFlags(argc, argv, &rest_args);
 
-  sentencepiece::NormalizerSpec spec;
+  NormalizerSpec spec;
 
   if (!FLAGS_model.empty()) {
-    sentencepiece::SentencePieceProcessor sp;
+    ModelProto model_proto;
+    SentencePieceProcessor sp;
     CHECK_OK(sp.Load(FLAGS_model));
     spec = sp.model_proto().normalizer_spec();
   } else if (!FLAGS_normalization_rule_tsv.empty()) {
     spec.set_normalization_rule_tsv(FLAGS_normalization_rule_tsv);
-    CHECK_OK(Builder::PopulateNormalizerSpec(&spec));
+    CHECK_OK(SentencePieceTrainer::PopulateNormalizerSpec(&spec));
   } else if (!FLAGS_normalization_rule_name.empty()) {
     spec.set_name(FLAGS_normalization_rule_name);
-    CHECK_OK(Builder::PopulateNormalizerSpec(&spec));
+    CHECK_OK(SentencePieceTrainer::PopulateNormalizerSpec(&spec));
   } else {
     LOG(FATAL) << "Sets --model, normalization_rule_tsv, or "
                   "normalization_rule_name flag.";

@@ -62,7 +69,7 @@ int main(int argc, char *argv[]) {
     spec.set_remove_extra_whitespaces(FLAGS_remove_extra_whitespaces);
   }
 
-  sentencepiece::normalizer::Normalizer normalizer(spec);
+  const Normalizer normalizer(spec);
   sentencepiece::io::OutputBuffer output(FLAGS_output);
   CHECK_OK(output.status());
src/trainer_interface.cc (inferred)

@@ -283,29 +283,37 @@ void TrainerInterface::SplitSentencesByWhitespace() {
 }
 
 util::Status TrainerInterface::Serialize(ModelProto *model_proto) const {
   RETURN_IF_ERROR(status());
 
   // Duplicated sentencepiece is not allowed.
-  std::unordered_set<std::string> dup;
+  std::set<std::string> dup;
 
 #define CHECK_PIECE(piece)                                  \
   CHECK_OR_RETURN(string_util::IsStructurallyValid(piece)); \
   CHECK_OR_RETURN(!piece.empty());                          \
   CHECK_OR_RETURN(dup.insert(piece).second) << piece << " is already defined";
 
-  for (const auto &w : meta_pieces_) {
-    auto *sp = model_proto->add_pieces();
-    sp->set_piece(w.first);
-    sp->set_type(w.second);
-    sp->set_score(0.0);
-    CHECK_NE_OR_RETURN(ModelProto::SentencePiece::NORMAL, sp->type());
-    CHECK_PIECE(sp->piece());
-  }
-
-  for (const auto &w : final_pieces_) {
-    auto *sp = model_proto->add_pieces();
-    sp->set_piece(w.first);
-    sp->set_score(w.second);
-    CHECK_PIECE(sp->piece());
-  }
+  size_t fid = 0;
+  for (int id = 0; id < trainer_spec_.vocab_size(); ++id) {
+    const auto it = meta_pieces_.find(id);
+    if (it != meta_pieces_.end()) {
+      auto *sp = model_proto->add_pieces();
+      sp->set_piece(it->second.first);
+      sp->set_type(it->second.second);
+      sp->set_score(0.0);
+      CHECK_EQ_OR_RETURN(model_proto->pieces_size() - 1, it->first);
+      CHECK_NE_OR_RETURN(ModelProto::SentencePiece::NORMAL, sp->type());
+      CHECK_PIECE(sp->piece());
+    } else if (fid < final_pieces_.size()) {
+      const auto &w = final_pieces_[fid++];
+      auto *sp = model_proto->add_pieces();
+      sp->set_piece(w.first);
+      sp->set_score(w.second);
+      CHECK_PIECE(sp->piece());
+    }
+  }
+  CHECK_EQ_OR_RETURN(fid, final_pieces_.size());
 
   *(model_proto->mutable_trainer_spec()) = trainer_spec_;
   *(model_proto->mutable_normalizer_spec()) = normalizer_spec_;

@@ -314,13 +322,13 @@ util::Status TrainerInterface::Serialize(ModelProto *model_proto) const {
       trainer_spec_.model_type() == TrainerSpec::CHAR) {
     CHECK_GE_OR_RETURN(trainer_spec_.vocab_size(), model_proto->pieces_size());
     CHECK_GE_OR_RETURN(trainer_spec_.vocab_size(),
-                       static_cast<int>(dup.size()));
+                       static_cast<int32>(dup.size()));
     model_proto->mutable_trainer_spec()->set_vocab_size(
         model_proto->pieces_size());
   } else {
-    CHECK_OR_RETURN(trainer_spec_.vocab_size() == model_proto->pieces_size() &&
-                    trainer_spec_.vocab_size() == static_cast<int>(dup.size()))
-        << "Use --hard_vocab_limit=false to make the vocab size `soft limit`.";
+    CHECK_EQ_OR_RETURN(trainer_spec_.vocab_size(), model_proto->pieces_size());
+    CHECK_EQ_OR_RETURN(trainer_spec_.vocab_size(),
+                       static_cast<int32>(dup.size()));
   }
 
   return util::OkStatus();
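Serialize now walks vocab ids in order, placing each reserved meta piece at its configured id and filling the gaps with learned pieces. A toy model of just that interleaving (plain C++ with made-up data, not the real classes):

#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
  // Reserved pieces keyed by vocab id; ids need not be contiguous anymore.
  std::map<int, std::string> meta = {{0, "<unk>"}, {10, "<s>"}, {20, "</s>"}};
  std::vector<std::string> learned = {"the", "a", "and"};

  size_t fid = 0;
  const int vocab_size = 24;
  for (int id = 0; id < vocab_size; ++id) {
    auto it = meta.find(id);
    if (it != meta.end())
      std::printf("%d\t%s (meta)\n", id, it->second.c_str());
    else if (fid < learned.size())
      std::printf("%d\t%s\n", id, learned[fid++].c_str());
  }
  // Prints: 0 <unk>, 1 the, 2 a, 3 and, 10 <s>, 20 </s>.
  return 0;
}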
@@ -361,40 +369,64 @@ util::Status TrainerInterface::Save() const {
 
 util::Status TrainerInterface::InitMetaPieces() {
   CHECK_OR_RETURN(meta_pieces_.empty());
 
-  std::vector<std::pair<int, std::string>> ids;
-  if (trainer_spec_.unk_id() >= 0)
-    ids.emplace_back(trainer_spec_.unk_id(), kUNK);
-  if (trainer_spec_.bos_id() >= 0)
-    ids.emplace_back(trainer_spec_.bos_id(), kBOS);
-  if (trainer_spec_.eos_id() >= 0)
-    ids.emplace_back(trainer_spec_.eos_id(), kEOS);
-  if (trainer_spec_.pad_id() >= 0)
-    ids.emplace_back(trainer_spec_.pad_id(), kPAD);
-
-  std::sort(ids.begin(), ids.end());
-
-  int prev_id = -1;
   bool has_unk = false;
-  for (const auto &p : ids) {
-    CHECK_EQ_OR_RETURN(prev_id + 1, p.first)
-        << "ID for `" << p.second << "` must be " << prev_id + 1;
-    prev_id = p.first;
-    CHECK_EQ_OR_RETURN(static_cast<int>(meta_pieces_.size()), p.first);
-    if (p.second == kUNK) has_unk = true;
-    meta_pieces_.emplace_back(
-        p.second, (p.second == kUNK ? ModelProto::SentencePiece::UNKNOWN
-                                    : ModelProto::SentencePiece::CONTROL));
-  }
+
+  auto insert_id = [&has_unk, this](int id, const std::string &w) -> bool {
+    if (id < 0) return true;
+    if (id >= trainer_spec_.vocab_size() ||
+        meta_pieces_.find(id) != meta_pieces_.end() || (has_unk && w == kUNK))
+      return false;
+    if (w == kUNK) has_unk = true;
+    meta_pieces_[id] =
+        std::make_pair(w, w == kUNK ? ModelProto::SentencePiece::UNKNOWN
+                                    : ModelProto::SentencePiece::CONTROL);
+    return true;
+  };
+
+  CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), kUNK));
+  CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), kBOS));
+  CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), kEOS));
+  CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), kPAD));
+
+  CHECK_OR_RETURN(has_unk) << kUNK << " must be defined.";
+
+  std::set<std::string> dup;
+
+  int id = 0;
+  auto insert_meta_symbol = [&id, &dup, this](
+                                const std::string &w,
+                                ModelProto::SentencePiece::Type type) -> bool {
+    if (!dup.insert(w).second) {
+      LOG(ERROR) << w << " is already defined.";
+      return false;
+    }
+
+    if (w == kUNK) {
+      LOG(ERROR) << "<unk> must not be defined with --control_symbols and "
+                    "--user_defined_symbols.";
+      return false;
+    }
+
+    if (w == kBOS && trainer_spec_.bos_id() >= 0) {
+      meta_pieces_[trainer_spec_.bos_id()].second = type;
+    } else if (w == kEOS && trainer_spec_.eos_id() >= 0) {
+      meta_pieces_[trainer_spec_.eos_id()].second = type;
+    } else if (w == kPAD && trainer_spec_.pad_id() >= 0) {
+      meta_pieces_[trainer_spec_.pad_id()].second = type;
+    } else {
+      while (meta_pieces_.find(id) != meta_pieces_.end()) ++id;
+      meta_pieces_[id] = std::make_pair(w, type);
+    }
+    return true;
+  };
 
   for (const auto &w : trainer_spec_.control_symbols()) {
-    meta_pieces_.emplace_back(w, ModelProto::SentencePiece::CONTROL);
+    CHECK_OR_RETURN(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL));
   }
 
   for (const auto &w : trainer_spec_.user_defined_symbols()) {
-    meta_pieces_.emplace_back(w, ModelProto::SentencePiece::USER_DEFINED);
+    CHECK_OR_RETURN(
+        insert_meta_symbol(w, ModelProto::SentencePiece::USER_DEFINED));
   }
 
   return util::OkStatus();
src/trainer_interface.h (inferred)

@@ -111,8 +111,8 @@ class TrainerInterface {
   NormalizerSpec normalizer_spec_;
 
   // Reserved control pieces. e.g., <unk>, <s>, </s>.
-  // The index corresponds to vocab id.
-  std::vector<std::pair<std::string, ModelProto::SentencePiece::Type>>
+  // key is vocab id.
+  std::map<int, std::pair<std::string, ModelProto::SentencePiece::Type>>
       meta_pieces_;
 
   // Detect errors on initialization.
src/trainer_interface_test.cc (inferred)

@@ -82,10 +82,12 @@ TEST(TrainerInterfaceTest, IsValidSentencePieceTest) {
 }
 
 TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
-  TrainerSpec trainer_spec;
+  TrainerSpec base_trainer_spec;
   NormalizerSpec normalizer_spec;
-  trainer_spec.set_model_prefix("model");
-  trainer_spec.add_input("input");
+  base_trainer_spec.set_model_prefix("model");
+  base_trainer_spec.add_input("input");
+
+  auto trainer_spec = base_trainer_spec;
 
   // Check default values.
   EXPECT_EQ(0, trainer_spec.unk_id());

@@ -94,6 +96,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
   EXPECT_EQ(-1, trainer_spec.pad_id());
 
   {
+    auto trainer_spec = base_trainer_spec;
     trainer_spec.set_unk_id(0);
     trainer_spec.set_bos_id(1);
     trainer_spec.set_eos_id(2);

@@ -108,6 +111,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
   }
 
   {
+    auto trainer_spec = base_trainer_spec;
     trainer_spec.set_unk_id(0);
     trainer_spec.set_bos_id(3);
     trainer_spec.set_eos_id(2);

@@ -122,6 +126,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
   }
 
   {
+    auto trainer_spec = base_trainer_spec;
    trainer_spec.set_unk_id(0);
     trainer_spec.set_bos_id(-1);
     trainer_spec.set_eos_id(1);

@@ -134,6 +139,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
   }
 
   {
+    auto trainer_spec = base_trainer_spec;
     trainer_spec.set_unk_id(0);
     trainer_spec.set_bos_id(-1);
     trainer_spec.set_eos_id(-1);

@@ -145,6 +151,7 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
   }
 
   {
+    auto trainer_spec = base_trainer_spec;
     trainer_spec.set_unk_id(0);
     trainer_spec.set_bos_id(1);
     trainer_spec.set_eos_id(2);

@@ -167,21 +174,96 @@ TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest) {
   }
 
   {
+    // ID is not contiguous.
+    auto trainer_spec = base_trainer_spec;
     trainer_spec.set_unk_id(0);
     trainer_spec.set_bos_id(-1);
     trainer_spec.set_eos_id(2);
     TrainerInterface trainer(trainer_spec, normalizer_spec);
-    EXPECT_NOT_OK(trainer.status());
+    EXPECT_OK(trainer.status());
   }
 
   {
+    // UNK is not defined.
+    auto trainer_spec = base_trainer_spec;
     trainer_spec.set_unk_id(-1);
     trainer_spec.set_bos_id(0);
     trainer_spec.set_eos_id(1);
     TrainerInterface trainer(trainer_spec, normalizer_spec);
-    EXPECT_NOT_OK(trainer.status());
+    EXPECT_FALSE(trainer.status().ok());
   }
+
+  {
+    // UNK is out-of-range.
+    auto trainer_spec = base_trainer_spec;
+    trainer_spec.set_unk_id(640000);
+    trainer_spec.set_bos_id(0);
+    trainer_spec.set_eos_id(1);
+    TrainerInterface trainer(trainer_spec, normalizer_spec);
+    EXPECT_FALSE(trainer.status().ok());
+  }
+
+  {
+    auto trainer_spec = base_trainer_spec;
+    trainer_spec.set_vocab_size(32000);
+    trainer_spec.set_unk_id(32000 - 1);
+    trainer_spec.set_bos_id(32000 - 100);
+    trainer_spec.set_eos_id(32000 - 200);
+    TrainerInterface trainer(trainer_spec, normalizer_spec);
+    EXPECT_OK(trainer.status());
+  }
+
+  {
+    // Cannot assign <unk> as control symbol.
+    auto trainer_spec = base_trainer_spec;
+    trainer_spec.set_unk_id(0);
+    trainer_spec.set_bos_id(1);
+    trainer_spec.set_eos_id(2);
+    trainer_spec.add_control_symbols("<unk>");
+    TrainerInterface trainer(trainer_spec, normalizer_spec);
+    EXPECT_FALSE(trainer.status().ok());
+  }
+
+  {
+    // Dup.
+    auto trainer_spec = base_trainer_spec;
+    trainer_spec.add_control_symbols("<foo>");
+    trainer_spec.add_control_symbols("<foo>");
+    TrainerInterface trainer(trainer_spec, normalizer_spec);
+    EXPECT_FALSE(trainer.status().ok());
+  }
+
+  {
+    auto trainer_spec = base_trainer_spec;
+    trainer_spec.set_unk_id(0);
+    trainer_spec.set_bos_id(10);
+    trainer_spec.set_eos_id(20);
+    trainer_spec.set_pad_id(30);
+
+    // <s>, <pad> are treated as USER_DEFIEND,
+    // </s> is CONTROL.
+    trainer_spec.add_user_defined_symbols("<s>");
+    trainer_spec.add_user_defined_symbols("<pad>");
+    trainer_spec.add_user_defined_symbols("foo");
+    TrainerInterface trainer(trainer_spec, normalizer_spec);
+    EXPECT_OK(trainer.status());
+
+    EXPECT_EQ(5, trainer.meta_pieces_.size());
+    EXPECT_EQ("<unk>", trainer.meta_pieces_[0].first);
+    EXPECT_EQ("<s>", trainer.meta_pieces_[10].first);
+    EXPECT_EQ("</s>", trainer.meta_pieces_[20].first);
+    EXPECT_EQ("<pad>", trainer.meta_pieces_[30].first);
+    EXPECT_EQ("foo", trainer.meta_pieces_[1].first);
+
+    EXPECT_EQ(ModelProto::SentencePiece::UNKNOWN,
+              trainer.meta_pieces_[0].second);
+    EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED,
+              trainer.meta_pieces_[10].second);
+    EXPECT_EQ(ModelProto::SentencePiece::CONTROL,
+              trainer.meta_pieces_[20].second);
+    EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED,
+              trainer.meta_pieces_[30].second);
+    EXPECT_EQ(ModelProto::SentencePiece::USER_DEFINED,
+              trainer.meta_pieces_[1].second);
+  }
 }
src/unigram_model_trainer_test.cc (inferred)

@@ -13,8 +13,6 @@
 // limitations under the License.!
 
 #include "unigram_model_trainer.h"
-#include "builder.h"
-#include "normalizer.h"
 #include "sentencepiece_model.pb.h"
 #include "sentencepiece_processor.h"
 #include "testharness.h"

@@ -39,7 +37,6 @@ TEST(UnigramTrainerTest, EndToEndTest) {
 
   NormalizerSpec normalizer_spec;
   normalizer_spec.set_name("identity");
-  EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
 
   constexpr int kVocabSize = 8000;
   trainer_spec.set_vocab_size(kVocabSize);
src/word_model_trainer_test.cc (inferred)

@@ -14,7 +14,6 @@
 
 #include "word_model_trainer.h"
 
-#include "builder.h"
 #include "sentencepiece_processor.h"
 #include "testharness.h"
 #include "util.h"

@@ -46,7 +45,6 @@ std::string RunTrainer(const std::vector<std::string> &input, int size) {
 
   NormalizerSpec normalizer_spec;
   normalizer_spec.set_name("identity");
-  EXPECT_OK(normalizer::Builder::PopulateNormalizerSpec(&normalizer_spec));
   normalizer_spec.set_add_dummy_prefix(true);
 
   Trainer trainer(trainer_spec, normalizer_spec);