stop normalization for user_defined_symbols

This commit is contained in:
Taku Kudo 2018-11-08 17:26:14 +09:00
parent bcb1363871
commit c950219a07
14 changed files with 184 additions and 130 deletions

View File

@ -28,7 +28,7 @@ namespace bpe {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
InitializePieces(true /* use prefix matcher */);
InitializePieces();
}
Model::~Model() {}

View File

@ -20,7 +20,7 @@ namespace character {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
InitializePieces(true /* use prefix matcher */);
InitializePieces();
}
Model::~Model() {}

View File

@ -20,56 +20,6 @@
namespace sentencepiece {
PrefixMatcher::PrefixMatcher(const std::set<absl::string_view> &dic) {
if (dic.empty()) return;
std::vector<const char *> key;
key.reserve(dic.size());
for (const auto &it : dic) key.push_back(it.data());
trie_ = port::MakeUnique<Darts::DoubleArray>();
CHECK_EQ(0, trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
nullptr));
}
int PrefixMatcher::PrefixMatch(absl::string_view w, bool *found) const {
if (trie_ == nullptr) {
if (found) *found = false;
return std::min<int>(w.size(), string_util::OneCharLen(w.data()));
}
constexpr int kResultSize = 64;
Darts::DoubleArray::result_pair_type trie_results[kResultSize];
const int num_nodes =
trie_->commonPrefixSearch(w.data(), trie_results, kResultSize, w.size());
if (found) *found = (num_nodes > 0);
if (num_nodes == 0) {
return std::min<int>(w.size(), string_util::OneCharLen(w.data()));
}
int mblen = 0;
for (int i = 0; i < num_nodes; ++i) {
mblen = std::max<int>(trie_results[i].length, mblen);
}
return mblen;
}
std::string PrefixMatcher::GlobalReplace(absl::string_view w,
absl::string_view out) const {
std::string result;
while (!w.empty()) {
bool found = false;
const int mblen = PrefixMatch(w, &found);
if (found) {
result.append(out.data(), out.size());
} else {
result.append(w.data(), mblen);
}
w.remove_prefix(mblen);
}
return result;
}
const char *ModelInterface::kUNK() { return "<unk>"; }
const char *ModelInterface::kBOS() { return "<s>"; }
const char *ModelInterface::kEOS() { return "</s>"; }
@ -91,7 +41,7 @@ int ModelInterface::PieceToId(absl::string_view piece) const {
return unk_id_;
}
void ModelInterface::InitializePieces(bool use_prefix_matcher) {
void ModelInterface::InitializePieces() {
pieces_.clear();
reserved_id_map_.clear();
unk_id_ = -1;
@ -115,8 +65,7 @@ void ModelInterface::InitializePieces(bool use_prefix_matcher) {
return;
}
if (use_prefix_matcher &&
sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
if (sp.type() == ModelProto::SentencePiece::USER_DEFINED) {
user_defined_symbols.insert(sp.piece());
}
@ -134,9 +83,7 @@ void ModelInterface::InitializePieces(bool use_prefix_matcher) {
return;
}
if (use_prefix_matcher) {
matcher_ = port::MakeUnique<PrefixMatcher>(user_defined_symbols);
}
matcher_ = port::MakeUnique<normalizer::PrefixMatcher>(user_defined_symbols);
}
std::vector<absl::string_view> SplitIntoWords(absl::string_view text) {

View File

@ -23,6 +23,7 @@
#include <vector>
#include "common.h"
#include "normalizer.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "third_party/absl/strings/string_view.h"
@ -39,26 +40,6 @@ using NBestEncodeResult = std::vector<std::pair<EncodeResult, float>>;
class ModelProto;
// Given a list of strings, finds the longest string which is a
// prefix of a query.
class PrefixMatcher {
public:
// Initializes the PrefixMatcher with `dic`.
explicit PrefixMatcher(const std::set<absl::string_view> &dic);
// Finds the longest string in dic, which is a prefix of `w`.
// Returns the UTF8 byte length of matched string.
// `found` is set if a prefix match exists.
// If no entry is found, consumes one Unicode character.
int PrefixMatch(absl::string_view w, bool *found = nullptr) const;
// Replaces entries in `w` with `out`.
std::string GlobalReplace(absl::string_view w, absl::string_view out) const;
private:
std::unique_ptr<Darts::DoubleArray> trie_;
};
// Underlying model interface.
// Given a normalized string, returns a sequence of sentence pieces with ids.
class ModelInterface {
@ -83,6 +64,10 @@ class ModelInterface {
virtual const ModelProto &model_proto() const { return *model_proto_; }
virtual const normalizer::PrefixMatcher *prefix_matcher() const {
return matcher_.get();
}
// Given a normalized string, returns a sequence of sentence pieces with ids.
// The concatenation of pieces must be the same as `normalized`.
virtual EncodeResult Encode(absl::string_view normalized) const = 0;
@ -146,7 +131,7 @@ class ModelInterface {
}
protected:
void InitializePieces(bool use_prefix_matcher);
void InitializePieces();
// Non-virtual (inlined) implementation for faster execution.
inline float GetScoreInlined(int id) const {
@ -176,7 +161,7 @@ class ModelInterface {
const ModelProto *model_proto_ = nullptr;
// PrefixMatcher for user defined symbols.
std::unique_ptr<PrefixMatcher> matcher_;
std::unique_ptr<normalizer::PrefixMatcher> matcher_;
// piece -> id map for normal pieces
PieceToIdMap pieces_;

View File

@ -265,50 +265,5 @@ TEST(ModelInterfaceTest, SplitIntoWordsTest) {
}
}
TEST(ModelInterfaceTest, PrefixMatcherTest) {
const PrefixMatcher matcher({"abc", "ab", "xy", "京都"});
bool found;
EXPECT_EQ(1, matcher.PrefixMatch("test", &found));
EXPECT_FALSE(found);
EXPECT_EQ(3, matcher.PrefixMatch("abcd", &found));
EXPECT_TRUE(found);
EXPECT_EQ(2, matcher.PrefixMatch("abxy", &found));
EXPECT_TRUE(found);
EXPECT_EQ(1, matcher.PrefixMatch("x", &found));
EXPECT_FALSE(found);
EXPECT_EQ(2, matcher.PrefixMatch("xyz", &found));
EXPECT_TRUE(found);
EXPECT_EQ(6, matcher.PrefixMatch("京都大学", &found));
EXPECT_TRUE(found);
EXPECT_EQ(3, matcher.PrefixMatch("東京大学", &found));
EXPECT_FALSE(found);
EXPECT_EQ("", matcher.GlobalReplace("", ""));
EXPECT_EQ("", matcher.GlobalReplace("abc", ""));
EXPECT_EQ("--de-pqr", matcher.GlobalReplace("xyabcdeabpqr", "-"));
}
TEST(ModelInterfaceTest, PrefixMatcherWithEmptyTest) {
const PrefixMatcher matcher({});
bool found;
EXPECT_EQ(1, matcher.PrefixMatch("test", &found));
EXPECT_FALSE(found);
EXPECT_EQ(1, matcher.PrefixMatch("abcd", &found));
EXPECT_FALSE(found);
EXPECT_EQ(1, matcher.PrefixMatch("abxy", &found));
EXPECT_FALSE(found);
EXPECT_EQ(1, matcher.PrefixMatch("x", &found));
EXPECT_FALSE(found);
EXPECT_EQ(1, matcher.PrefixMatch("xyz", &found));
EXPECT_FALSE(found);
EXPECT_EQ(3, matcher.PrefixMatch("京都大学", &found));
EXPECT_FALSE(found);
EXPECT_EQ(3, matcher.PrefixMatch("東京大学", &found));
EXPECT_FALSE(found);
EXPECT_EQ("", matcher.GlobalReplace("", ""));
EXPECT_EQ("abc", matcher.GlobalReplace("abc", ""));
}
} // namespace
} // namespace sentencepiece

View File

@ -174,6 +174,12 @@ std::pair<absl::string_view, int> Normalizer::NormalizePrefix(
if (input.empty()) return result;
if (matcher_ != nullptr) {
bool found = false;
const int mblen = matcher_->PrefixMatch(input, &found);
if (found) return std::make_pair(input.substr(0, mblen), mblen);
}
size_t longest_length = 0;
int longest_value = 0;
@ -254,5 +260,56 @@ util::Status Normalizer::DecodePrecompiledCharsMap(
return util::OkStatus();
}
PrefixMatcher::PrefixMatcher(const std::set<absl::string_view> &dic) {
if (dic.empty()) return;
std::vector<const char *> key;
key.reserve(dic.size());
for (const auto &it : dic) key.push_back(it.data());
trie_ = port::MakeUnique<Darts::DoubleArray>();
CHECK_EQ(0, trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
nullptr));
}
int PrefixMatcher::PrefixMatch(absl::string_view w, bool *found) const {
if (trie_ == nullptr) {
if (found) *found = false;
return std::min<int>(w.size(), string_util::OneCharLen(w.data()));
}
constexpr int kResultSize = 64;
Darts::DoubleArray::result_pair_type trie_results[kResultSize];
const int num_nodes =
trie_->commonPrefixSearch(w.data(), trie_results, kResultSize, w.size());
if (found) *found = (num_nodes > 0);
if (num_nodes == 0) {
return std::min<int>(w.size(), string_util::OneCharLen(w.data()));
}
int mblen = 0;
for (int i = 0; i < num_nodes; ++i) {
mblen = std::max<int>(trie_results[i].length, mblen);
}
return mblen;
}
std::string PrefixMatcher::GlobalReplace(absl::string_view w,
absl::string_view out) const {
std::string result;
while (!w.empty()) {
bool found = false;
const int mblen = PrefixMatch(w, &found);
if (found) {
result.append(out.data(), out.size());
} else {
result.append(w.data(), mblen);
}
w.remove_prefix(mblen);
}
return result;
}
} // namespace normalizer
} // namespace sentencepiece

View File

@ -29,6 +29,26 @@
namespace sentencepiece {
namespace normalizer {
// Given a list of strings, finds the longest string which is a
// prefix of a query.
class PrefixMatcher {
public:
// Initializes the PrefixMatcher with `dic`.
explicit PrefixMatcher(const std::set<absl::string_view> &dic);
// Finds the longest string in dic, which is a prefix of `w`.
// Returns the UTF8 byte length of matched string.
// `found` is set if a prefix match exists.
// If no entry is found, consumes one Unicode character.
int PrefixMatch(absl::string_view w, bool *found = nullptr) const;
// Replaces entries in `w` with `out`.
std::string GlobalReplace(absl::string_view w, absl::string_view out) const;
private:
std::unique_ptr<Darts::DoubleArray> trie_;
};
// Normalizer implements a simple text normalizer with
// user-defined string-to-string rules and leftmost longest
// matching. The rules of Normalizer are built with
@ -46,6 +66,10 @@ class Normalizer {
explicit Normalizer(const NormalizerSpec &spec);
virtual ~Normalizer();
virtual void SetPrefixMatcher(const PrefixMatcher *matcher) {
matcher_ = matcher;
}
// Returns Status.
// Normalizes function is valid only when status is OK.
virtual util::Status status() const { return status_; }
@ -59,7 +83,8 @@ class Normalizer {
// - Adds a prefix space.
// - Replaces a space with a meta symbol.
// - Removing heading, tailing and other redundant spaces.
virtual util::Status Normalize(absl::string_view input, std::string *normalized,
virtual util::Status Normalize(absl::string_view input,
std::string *normalized,
std::vector<size_t> *norm_to_orig) const;
// Returns a normalized string without alignments.
@ -83,7 +108,8 @@ class Normalizer {
// output.append(p.first.data(), p.first.size());
// input.remove_prefix(p.second);
// }
std::pair<absl::string_view, int> NormalizePrefix(absl::string_view input) const;
std::pair<absl::string_view, int> NormalizePrefix(
absl::string_view input) const;
// Encodes trie_blob and normalized string and return compiled blob.
static std::string EncodePrecompiledCharsMap(absl::string_view trie_blob,
@ -108,6 +134,9 @@ class Normalizer {
// Spec for normalization.
const NormalizerSpec *spec_;
// Prefix matcher;
const PrefixMatcher *matcher_ = nullptr;
// Normalizer's status.
util::Status status_;
};

View File

@ -361,5 +361,51 @@ TEST(NormalizerTest, StatusTest) {
EXPECT_TRUE(normalizer.status().ok());
}
}
TEST(NormalizerTest, PrefixMatcherTest) {
const PrefixMatcher matcher({"abc", "ab", "xy", "京都"});
bool found;
EXPECT_EQ(1, matcher.PrefixMatch("test", &found));
EXPECT_FALSE(found);
EXPECT_EQ(3, matcher.PrefixMatch("abcd", &found));
EXPECT_TRUE(found);
EXPECT_EQ(2, matcher.PrefixMatch("abxy", &found));
EXPECT_TRUE(found);
EXPECT_EQ(1, matcher.PrefixMatch("x", &found));
EXPECT_FALSE(found);
EXPECT_EQ(2, matcher.PrefixMatch("xyz", &found));
EXPECT_TRUE(found);
EXPECT_EQ(6, matcher.PrefixMatch("京都大学", &found));
EXPECT_TRUE(found);
EXPECT_EQ(3, matcher.PrefixMatch("東京大学", &found));
EXPECT_FALSE(found);
EXPECT_EQ("", matcher.GlobalReplace("", ""));
EXPECT_EQ("", matcher.GlobalReplace("abc", ""));
EXPECT_EQ("--de-pqr", matcher.GlobalReplace("xyabcdeabpqr", "-"));
}
TEST(NormalizerTest, PrefixMatcherWithEmptyTest) {
const PrefixMatcher matcher({});
bool found;
EXPECT_EQ(1, matcher.PrefixMatch("test", &found));
EXPECT_FALSE(found);
EXPECT_EQ(1, matcher.PrefixMatch("abcd", &found));
EXPECT_FALSE(found);
EXPECT_EQ(1, matcher.PrefixMatch("abxy", &found));
EXPECT_FALSE(found);
EXPECT_EQ(1, matcher.PrefixMatch("x", &found));
EXPECT_FALSE(found);
EXPECT_EQ(1, matcher.PrefixMatch("xyz", &found));
EXPECT_FALSE(found);
EXPECT_EQ(3, matcher.PrefixMatch("京都大学", &found));
EXPECT_FALSE(found);
EXPECT_EQ(3, matcher.PrefixMatch("東京大学", &found));
EXPECT_FALSE(found);
EXPECT_EQ("", matcher.GlobalReplace("", ""));
EXPECT_EQ("abc", matcher.GlobalReplace("abc", ""));
}
} // namespace normalizer
} // namespace sentencepiece

View File

@ -81,6 +81,9 @@ util::Status SentencePieceProcessor::Load(
normalizer_ =
port::MakeUnique<normalizer::Normalizer>(model_proto_->normalizer_spec());
// Escapes user-defined-symbols in normalizer.
normalizer_->SetPrefixMatcher(model_->prefix_matcher());
RETURN_IF_ERROR(status());
// Running self-testing.

View File

@ -1002,6 +1002,38 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {
}
}
TEST(SentencePieceProcessorTest, SkipNormalizationTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::USER_DEFINED);
sp2->set_piece("<USER>");
AddPiece(&model_proto, "a", 0.0);
AddPiece(&model_proto, "b", 0.3);
AddPiece(&model_proto, "c", 0.2);
AddPiece(&model_proto, "u", 0.2);
AddPiece(&model_proto, "s", 0.2);
AddPiece(&model_proto, "e", 0.2);
AddPiece(&model_proto, "r", 0.2);
*(model_proto.mutable_normalizer_spec()) =
SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc_cf");
SentencePieceProcessor sp;
sp.Load(model_proto);
std::vector<std::string> pieces;
EXPECT_OK(sp.Encode("AB<USER>C<uSEr>", &pieces));
for (const auto &sp : pieces) LOG(INFO) << sp;
EXPECT_EQ(std::vector<std::string>(
{WS, "a", "b", "<USER>", "c", "<", "u", "s", "e", "r", ">"}),
pieces);
}
TEST(SentencePieceProcessorTest, ExtraOptionsUndefinedTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();

View File

@ -161,7 +161,7 @@ util::Status TrainerInterface::LoadSentences() {
std::set<absl::string_view> meta_pieces_set;
for (const auto &it : meta_pieces_) meta_pieces_set.insert(it.second.first);
const PrefixMatcher meta_pieces_matcher(meta_pieces_set);
const normalizer::PrefixMatcher meta_pieces_matcher(meta_pieces_set);
random::ReservoirSampler<std::string> sampler(
trainer_spec_.self_test_sample_size());

View File

@ -475,7 +475,7 @@ void Model::BuildTrie(std::vector<std::pair<absl::string_view, int>> *pieces) {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
InitializePieces(false /* enable prefix matcher */);
InitializePieces();
min_score_ = FLT_MAX;
max_score_ = FLT_MIN;

View File

@ -20,7 +20,7 @@ namespace word {
Model::Model(const ModelProto &model_proto) {
model_proto_ = &model_proto;
InitializePieces(false /* enable string matcher */);
InitializePieces();
}
Model::~Model() {}