returns unicode character offsets in normalize method

Taku Kudo 2024-01-22 07:19:04 +00:00
parent 6b468a0e01
commit 41c4b7f080
6 changed files with 93 additions and 22 deletions
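
With this change, Normalize(..., with_offsets=True) reports offsets as Unicode character positions when the input is a Python str, while bytes input keeps the previous UTF-8 byte offsets. The usage sketch below is illustrative only: the model path is a placeholder, and the commented values are the ones asserted in the updated Python test further down.

import sentencepiece as spm

# Placeholder model path; any model whose normalizer folds full-width
# characters (NFKC-style rules) shows the same kind of behavior.
sp = spm.SentencePieceProcessor(model_file='test/test_model.model')

# str input: offsets now index Unicode characters of the original string.
normalized, offsets = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)
# normalized == '▁KADOKAWAABC'
# offsets    == [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# bytes input: offsets remain UTF-8 byte positions, as before.
normalized, offsets = sp.Normalize('ＫＡＤＯＫＡＷＡABC'.encode('utf8'),
                                   with_offsets=True)
# normalized == '▁KADOKAWAABC'.encode('utf8')
# offsets    == [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27]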

View File

@@ -368,6 +368,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
%ignore sentencepiece::SentencePieceTrainer::PieceProcecssor;
%ignore sentencepiece::SentencePieceTrainer::SetPretokenizerForTraining;
%ignore sentencepiece::SentencePieceTrainer::GetPretokenizerForTraining;
%ignore sentencepiece::ConvertToUnicodeAlignment;
%ignore sentencepiece::SentencePieceNormalizer::Load;
%ignore sentencepiece::SentencePieceNormalizer::Normalize;
@@ -1838,8 +1839,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
}
}
// Types for normalized string and offset
%typemap(out) std::pair<std::string, std::vector<size_t>> {
  PyObject *input_type = resultobj;
  if (PyInputString::IsUnicode(input_type)) {
    sentencepiece::ConvertToUnicodeAlignment(arg2, $1.first, &$1.second);
  }
  PyObject *obj = PyList_New($1.second.size());
  for (size_t i = 0; i < $1.second.size(); ++i) {
    PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>($1.second[i])));

View File

@@ -8633,6 +8633,9 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor__NormalizeWithOffsets(PyObject
  }
  {
    PyObject *input_type = resultobj;
    if (PyInputString::IsUnicode(input_type)) {
      sentencepiece::ConvertToUnicodeAlignment(arg2, (&result)->first, &(&result)->second);
    }
    PyObject *obj = PyList_New((&result)->second.size());
    for (size_t i = 0; i < (&result)->second.size(); ++i) {
      PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>((&result)->second[i])));
@@ -9541,6 +9544,9 @@ SWIGINTERN PyObject *_wrap_SentencePieceNormalizer__NormalizeWithOffsets(PyObjec
  }
  {
    PyObject *input_type = resultobj;
    if (PyInputString::IsUnicode(input_type)) {
      sentencepiece::ConvertToUnicodeAlignment(arg2, (&result)->first, &(&result)->second);
    }
    PyObject *obj = PyList_New((&result)->second.size());
    for (size_t i = 0; i < (&result)->second.size(); ++i) {
      PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>((&result)->second[i])));

View File

@@ -804,6 +804,10 @@ class TestSentencepieceProcessor(unittest.TestCase):
x = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)
self.assertEqual('▁KADOKAWAABC', x[0])
self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])
x = sp.Normalize('ＫＡＤＯＫＡＷＡABC'.encode('utf8'), with_offsets=True)
self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0])
self.assertEqual(
[0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1]
)
@@ -815,14 +819,23 @@ class TestSentencepieceProcessor(unittest.TestCase):
['▁KADOKAWAABC', '▁平成'], sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
)
x = sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'], with_offsets=True)
x = sp.Normalize(
['ＫＡＤＯＫＡＷＡABC'.encode('utf8'), '㍻'.encode('utf8')],
with_offsets=True,
)
self.assertEqual(len(x), 2)
self.assertEqual('▁KADOKAWAABC', x[0][0])
self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0][0])
self.assertEqual(
[0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1]
)
x = sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'], with_offsets=True)
self.assertEqual(len(x), 2)
self.assertEqual('▁KADOKAWAABC', x[0][0])
self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])
self.assertEqual('▁平成', x[1][0])
self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0, 3], x[1][1])
self.assertEqual([0, 0, 0, 1], x[1][1])
def test_normalizer(self):
sp = spm.SentencePieceNormalizer(
@@ -832,9 +845,13 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual('KADOKAWAABC', sp.normalize('ＫＡＤＯＫＡＷＡABC'))
self.assertEqual('KADOKAWAABC', sp.Normalize('ＫＡＤＯＫＡＷＡABC'))
x = sp.Normalize('ＫＡＤＯＫＡＷＡABC'.encode('utf8'), with_offsets=True)
self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0])
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1])
x = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)
self.assertEqual('KADOKAWAABC', x[0])
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1])
self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])
self.assertEqual(
['KADOKAWAABC', '平成'], sp.normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
@@ -843,13 +860,20 @@ class TestSentencepieceProcessor(unittest.TestCase):
['KADOKAWAABC', '平成'], sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
)
x = sp.Normalize(
['ＫＡＤＯＫＡＷＡABC'.encode('utf8'), '㍻'.encode('utf8')],
with_offsets=True,
)
self.assertEqual(len(x), 2)
self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0][0])
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1])
x = sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'], with_offsets=True)
self.assertEqual(len(x), 2)
self.assertEqual('KADOKAWAABC', x[0][0])
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1])
self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])
self.assertEqual('平成', x[1][0])
self.assertEqual([0, 0, 0, 0, 0, 0, 3], x[1][1])
self.assertEqual([0, 0, 1], x[1][1])
sp = spm.SentencePieceNormalizer(
model_file=os.path.join('test', 'test_model.model'),

View File

@@ -366,4 +366,35 @@ std::string SentencePieceNormalizer::serialized_model_proto() const {
  return model_proto_ ? model_proto_->SerializeAsString() : "";
}

void ConvertToUnicodeAlignment(absl::string_view orig, absl::string_view norm,
                               std::vector<size_t> *norm_to_orig) {
  auto utf8_to_unicode_offsets = [](absl::string_view str) {
    std::vector<int> utf8_to_unicode(str.size() + 1, 0);
    size_t prev = 0;
    int ulen = 0;
    while (!str.empty()) {
      const size_t mblen =
          std::max<int>(1, string_util::OneCharLen(str.data()));
      for (int i = prev; i < prev + mblen; ++i) {
        utf8_to_unicode[i] = ulen;
      }
      ++ulen;
      prev += mblen;
      str.remove_prefix(mblen);
    }
    utf8_to_unicode[prev] = ulen;
    return utf8_to_unicode;
  };

  const auto orig_offsets = utf8_to_unicode_offsets(orig);
  const auto norm_offsets = utf8_to_unicode_offsets(norm);
  if (orig_offsets.empty() || norm_offsets.empty()) return;

  std::vector<size_t> result(norm_offsets.back() + 1, 0);
  for (int i = 0; i < norm_to_orig->size(); ++i) {
    result[norm_offsets[i]] = orig_offsets[(*norm_to_orig)[i]];
  }

  *norm_to_orig = std::move(result);
}
} // namespace sentencepiece
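
The remapping above can be expressed compactly in Python. This is a minimal sketch for readers of the diff, assuming valid UTF-8 input; the function names are hypothetical stand-ins that mirror ConvertToUnicodeAlignment and are not part of the library.

def utf8_to_unicode_offsets(data):
    # For each byte position in data, record the index of the Unicode
    # character that byte belongs to, plus one trailing entry holding the
    # total character count (mirrors the lambda above).
    offsets = []
    text = data.decode('utf-8')
    for char_index, ch in enumerate(text):
        offsets.extend([char_index] * len(ch.encode('utf-8')))
    offsets.append(len(text))
    return offsets

def convert_to_unicode_alignment(orig, norm, norm_to_orig):
    # norm_to_orig[i] is the byte offset in orig for byte offset i of norm,
    # plus one trailing entry for the end position (the layout produced by
    # Normalize). Returns the same alignment expressed in character offsets.
    orig_offsets = utf8_to_unicode_offsets(orig)
    norm_offsets = utf8_to_unicode_offsets(norm)
    result = [0] * (norm_offsets[-1] + 1)
    for i, byte_offset in enumerate(norm_to_orig):
        result[norm_offsets[i]] = orig_offsets[byte_offset]
    return result

# Example with a pair used in the tests: '㍻' normalizes to '▁平成'.
orig = '㍻'.encode('utf-8')      # 3 bytes, 1 character
norm = '▁平成'.encode('utf-8')   # 9 bytes, 3 characters
byte_alignment = [0] * 9 + [3]   # every output byte maps to byte 0; last entry is the consumed length
print(convert_to_unicode_alignment(orig, norm, byte_alignment))  # [0, 0, 0, 1]

Doing the conversion at the binding layer, as the SWIG typemap earlier in this commit does, keeps the C++ Normalize API byte-based and remaps only when the Python caller passed a str.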

View File

@@ -197,6 +197,10 @@ class SentencePieceNormalizer {
  std::unique_ptr<ModelProto> model_proto_;
};

// Converts the utf8 byte spans into Unicode char span.
void ConvertToUnicodeAlignment(absl::string_view orig, absl::string_view norm,
                               std::vector<size_t> *norm_to_orig);
} // namespace sentencepiece
#endif // SENTENCEPIECE_TRAINER_H_

View File

@@ -376,32 +376,33 @@ TEST(SentencePieceTrainerTest, NormalizationTest) {
trainer_spec.set_vocab_size(1000);
ASSERT_TRUE(SentencePieceTrainer::Train(trainer_spec).ok());
constexpr absl::string_view kInput = " ABC ";
{
SentencePieceProcessor sp;
EXPECT_OK(sp.Load(model_file));
EXPECT_EQ(sp.Normalize(" ABC "), "▁KADOKAWA▁ABC");
EXPECT_EQ(sp.Normalize(kInput), "▁KADOKAWA▁ABC");
std::string normalized;
std::vector<size_t> offsets;
EXPECT_OK(sp.Normalize(" ABC ", &normalized, &offsets));
EXPECT_OK(sp.Normalize(kInput, &normalized, &offsets));
EXPECT_EQ(normalized, "▁KADOKAWA▁ABC");
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21,
24, 24, 24, 27, 28, 29, 30}));
}
ConvertToUnicodeAlignment(kInput, normalized, &offsets);
EXPECT_EQ(offsets, std::vector<size_t>(
{0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14}));
{
SentencePieceNormalizer sp;
EXPECT_OK(sp.Load(model_file));
EXPECT_EQ(sp.Normalize(" ABC "), "▁KADOKAWA▁ABC");
EXPECT_OK(sp.Normalize("㍻元年", &normalized, &offsets));
EXPECT_EQ(normalized, "▁平成元年");
ConvertToUnicodeAlignment(kInput, normalized, &offsets);
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 0, 1, 2, 3}));
std::string normalized;
std::vector<size_t> offsets;
EXPECT_OK(sp.Normalize(" ABC ", &normalized, &offsets));
EXPECT_EQ(normalized, "▁KADOKAWA▁ABC");
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21,
24, 24, 24, 27, 28, 29, 30}));
EXPECT_OK(sp.Normalize("ｶﾞｲﾀﾞﾝｽ", &normalized, &offsets));
EXPECT_EQ(normalized, "▁ガイダンス");
ConvertToUnicodeAlignment(kInput, normalized, &offsets);
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 2, 3, 5, 6, 7}));
}
auto set_normalization_only = [](SentencePieceNormalizer *normalizer) {
@@ -417,7 +418,7 @@ TEST(SentencePieceTrainerTest, NormalizationTest) {
SentencePieceNormalizer sp;
EXPECT_OK(sp.Load(model_file));
set_normalization_only(&sp);
EXPECT_EQ(sp.Normalize(" ABC "), "KADOKAWA ABC ");
EXPECT_EQ(sp.Normalize(kInput), "KADOKAWA ABC ");
}
{