mirror of
https://github.com/google/sentencepiece.git
synced 2024-08-15 22:00:43 +03:00
returns unicode characetr offsets in normalize method
This commit is contained in:
parent
6b468a0e01
commit
41c4b7f080
@ -368,6 +368,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
|
||||
%ignore sentencepiece::SentencePieceTrainer::PieceProcecssor;
|
||||
%ignore sentencepiece::SentencePieceTrainer::SetPretokenizerForTraining;
|
||||
%ignore sentencepiece::SentencePieceTrainer::GetPretokenizerForTraining;
|
||||
%ignore sentencepiece::ConvertToUnicodeAlignment;
|
||||
|
||||
%ignore sentencepiece::SentencePieceNormalizer::Load;
|
||||
%ignore sentencepiece::SentencePieceNormalizer::Normalize;
|
||||
@ -1838,8 +1839,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
|
||||
}
|
||||
}
|
||||
|
||||
// Types for normalized string and offset
|
||||
%typemap(out) std::pair<std::string, std::vector<size_t>> {
|
||||
PyObject *input_type = resultobj;
|
||||
if (PyInputString::IsUnicode(input_type)) {
|
||||
sentencepiece::ConvertToUnicodeAlignment(arg2, $1.first, &$1.second);
|
||||
}
|
||||
PyObject *obj = PyList_New($1.second.size());
|
||||
for (size_t i = 0; i < $1.second.size(); ++i) {
|
||||
PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>($1.second[i])));
|
||||
|
@ -8633,6 +8633,9 @@ SWIGINTERN PyObject *_wrap_SentencePieceProcessor__NormalizeWithOffsets(PyObject
|
||||
}
|
||||
{
|
||||
PyObject *input_type = resultobj;
|
||||
if (PyInputString::IsUnicode(input_type)) {
|
||||
sentencepiece::ConvertToUnicodeAlignment(arg2, (&result)->first, &(&result)->second);
|
||||
}
|
||||
PyObject *obj = PyList_New((&result)->second.size());
|
||||
for (size_t i = 0; i < (&result)->second.size(); ++i) {
|
||||
PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>((&result)->second[i])));
|
||||
@ -9541,6 +9544,9 @@ SWIGINTERN PyObject *_wrap_SentencePieceNormalizer__NormalizeWithOffsets(PyObjec
|
||||
}
|
||||
{
|
||||
PyObject *input_type = resultobj;
|
||||
if (PyInputString::IsUnicode(input_type)) {
|
||||
sentencepiece::ConvertToUnicodeAlignment(arg2, (&result)->first, &(&result)->second);
|
||||
}
|
||||
PyObject *obj = PyList_New((&result)->second.size());
|
||||
for (size_t i = 0; i < (&result)->second.size(); ++i) {
|
||||
PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>((&result)->second[i])));
|
||||
|
@ -804,6 +804,10 @@ class TestSentencepieceProcessor(unittest.TestCase):
|
||||
|
||||
x = sp.Normalize('KADOKAWAABC', with_offsets=True)
|
||||
self.assertEqual('▁KADOKAWAABC', x[0])
|
||||
self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])
|
||||
|
||||
x = sp.Normalize('KADOKAWAABC'.encode('utf8'), with_offsets=True)
|
||||
self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0])
|
||||
self.assertEqual(
|
||||
[0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1]
|
||||
)
|
||||
@ -815,14 +819,23 @@ class TestSentencepieceProcessor(unittest.TestCase):
|
||||
['▁KADOKAWAABC', '▁平成'], sp.Normalize(['KADOKAWAABC', '㍻'])
|
||||
)
|
||||
|
||||
x = sp.Normalize(['KADOKAWAABC', '㍻'], with_offsets=True)
|
||||
x = sp.Normalize(
|
||||
['KADOKAWAABC'.encode('utf8'), '㍻'.encode('utf8')],
|
||||
with_offsets=True,
|
||||
)
|
||||
self.assertEqual(len(x), 2)
|
||||
self.assertEqual('▁KADOKAWAABC', x[0][0])
|
||||
self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0][0])
|
||||
self.assertEqual(
|
||||
[0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1]
|
||||
)
|
||||
|
||||
x = sp.Normalize(['KADOKAWAABC', '㍻'], with_offsets=True)
|
||||
self.assertEqual(len(x), 2)
|
||||
self.assertEqual('▁KADOKAWAABC', x[0][0])
|
||||
self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])
|
||||
|
||||
self.assertEqual('▁平成', x[1][0])
|
||||
self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0, 3], x[1][1])
|
||||
self.assertEqual([0, 0, 0, 1], x[1][1])
|
||||
|
||||
def test_normalizer(self):
|
||||
sp = spm.SentencePieceNormalizer(
|
||||
@ -832,9 +845,13 @@ class TestSentencepieceProcessor(unittest.TestCase):
|
||||
self.assertEqual('KADOKAWAABC', sp.normalize('KADOKAWAABC'))
|
||||
self.assertEqual('KADOKAWAABC', sp.Normalize('KADOKAWAABC'))
|
||||
|
||||
x = sp.Normalize('KADOKAWAABC'.encode('utf8'), with_offsets=True)
|
||||
self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0])
|
||||
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1])
|
||||
|
||||
x = sp.Normalize('KADOKAWAABC', with_offsets=True)
|
||||
self.assertEqual('KADOKAWAABC', x[0])
|
||||
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1])
|
||||
self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])
|
||||
|
||||
self.assertEqual(
|
||||
['KADOKAWAABC', '平成'], sp.normalize(['KADOKAWAABC', '㍻'])
|
||||
@ -843,13 +860,20 @@ class TestSentencepieceProcessor(unittest.TestCase):
|
||||
['KADOKAWAABC', '平成'], sp.Normalize(['KADOKAWAABC', '㍻'])
|
||||
)
|
||||
|
||||
x = sp.Normalize(
|
||||
['KADOKAWAABC'.encode('utf8'), '㍻'.encode('utf8')],
|
||||
with_offsets=True,
|
||||
)
|
||||
self.assertEqual(len(x), 2)
|
||||
self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0][0])
|
||||
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1])
|
||||
|
||||
x = sp.Normalize(['KADOKAWAABC', '㍻'], with_offsets=True)
|
||||
self.assertEqual(len(x), 2)
|
||||
self.assertEqual('KADOKAWAABC', x[0][0])
|
||||
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1])
|
||||
|
||||
self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])
|
||||
self.assertEqual('平成', x[1][0])
|
||||
self.assertEqual([0, 0, 0, 0, 0, 0, 3], x[1][1])
|
||||
self.assertEqual([0, 0, 1], x[1][1])
|
||||
|
||||
sp = spm.SentencePieceNormalizer(
|
||||
model_file=os.path.join('test', 'test_model.model'),
|
||||
|
@ -366,4 +366,35 @@ std::string SentencePieceNormalizer::serialized_model_proto() const {
|
||||
return model_proto_ ? model_proto_->SerializeAsString() : "";
|
||||
}
|
||||
|
||||
void ConvertToUnicodeAlignment(absl::string_view orig, absl::string_view norm,
|
||||
std::vector<size_t> *norm_to_orig) {
|
||||
auto utf8_to_unicode_offsets = [](absl::string_view str) {
|
||||
std::vector<int> utf8_to_unicode(str.size() + 1, 0);
|
||||
size_t prev = 0;
|
||||
int ulen = 0;
|
||||
while (!str.empty()) {
|
||||
const size_t mblen =
|
||||
std::max<int>(1, string_util::OneCharLen(str.data()));
|
||||
for (int i = prev; i < prev + mblen; ++i) {
|
||||
utf8_to_unicode[i] = ulen;
|
||||
}
|
||||
++ulen;
|
||||
prev += mblen;
|
||||
str.remove_prefix(mblen);
|
||||
}
|
||||
utf8_to_unicode[prev] = ulen;
|
||||
return utf8_to_unicode;
|
||||
};
|
||||
|
||||
const auto orig_offsets = utf8_to_unicode_offsets(orig);
|
||||
const auto norm_offsets = utf8_to_unicode_offsets(norm);
|
||||
if (orig_offsets.empty() || norm_offsets.empty()) return;
|
||||
|
||||
std::vector<size_t> result(norm_offsets.back() + 1, 0);
|
||||
for (int i = 0; i < norm_to_orig->size(); ++i) {
|
||||
result[norm_offsets[i]] = orig_offsets[(*norm_to_orig)[i]];
|
||||
}
|
||||
*norm_to_orig = std::move(result);
|
||||
}
|
||||
|
||||
} // namespace sentencepiece
|
||||
|
@ -197,6 +197,10 @@ class SentencePieceNormalizer {
|
||||
std::unique_ptr<ModelProto> model_proto_;
|
||||
};
|
||||
|
||||
// Converts the utf8 byte spans into Unicode char span.
|
||||
void ConvertToUnicodeAlignment(absl::string_view orig, absl::string_view norm,
|
||||
std::vector<size_t> *norm_to_orig);
|
||||
|
||||
} // namespace sentencepiece
|
||||
|
||||
#endif // SENTENCEPIECE_TRAINER_H_
|
||||
|
@ -376,32 +376,33 @@ TEST(SentencePieceTrainerTest, NormalizationTest) {
|
||||
trainer_spec.set_vocab_size(1000);
|
||||
ASSERT_TRUE(SentencePieceTrainer::Train(trainer_spec).ok());
|
||||
|
||||
constexpr absl::string_view kInput = "KADOKAWA ABC ";
|
||||
|
||||
{
|
||||
SentencePieceProcessor sp;
|
||||
EXPECT_OK(sp.Load(model_file));
|
||||
EXPECT_EQ(sp.Normalize("KADOKAWA ABC "), "▁KADOKAWA▁ABC");
|
||||
EXPECT_EQ(sp.Normalize(kInput), "▁KADOKAWA▁ABC");
|
||||
|
||||
std::string normalized;
|
||||
std::vector<size_t> offsets;
|
||||
|
||||
EXPECT_OK(sp.Normalize("KADOKAWA ABC ", &normalized, &offsets));
|
||||
EXPECT_OK(sp.Normalize(kInput, &normalized, &offsets));
|
||||
EXPECT_EQ(normalized, "▁KADOKAWA▁ABC");
|
||||
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21,
|
||||
24, 24, 24, 27, 28, 29, 30}));
|
||||
}
|
||||
ConvertToUnicodeAlignment(kInput, normalized, &offsets);
|
||||
EXPECT_EQ(offsets, std::vector<size_t>(
|
||||
{0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14}));
|
||||
|
||||
{
|
||||
SentencePieceNormalizer sp;
|
||||
EXPECT_OK(sp.Load(model_file));
|
||||
EXPECT_EQ(sp.Normalize("KADOKAWA ABC "), "▁KADOKAWA▁ABC");
|
||||
EXPECT_OK(sp.Normalize("㍻元年", &normalized, &offsets));
|
||||
EXPECT_EQ(normalized, "▁平成元年");
|
||||
ConvertToUnicodeAlignment(kInput, normalized, &offsets);
|
||||
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 0, 1, 2, 3}));
|
||||
|
||||
std::string normalized;
|
||||
std::vector<size_t> offsets;
|
||||
|
||||
EXPECT_OK(sp.Normalize("KADOKAWA ABC ", &normalized, &offsets));
|
||||
EXPECT_EQ(normalized, "▁KADOKAWA▁ABC");
|
||||
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21,
|
||||
24, 24, 24, 27, 28, 29, 30}));
|
||||
EXPECT_OK(sp.Normalize("ガイダンス", &normalized, &offsets));
|
||||
EXPECT_EQ(normalized, "▁ガイダンス");
|
||||
ConvertToUnicodeAlignment(kInput, normalized, &offsets);
|
||||
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 2, 3, 5, 6, 7}));
|
||||
}
|
||||
|
||||
auto set_normalization_only = [](SentencePieceNormalizer *normalizer) {
|
||||
@ -417,7 +418,7 @@ TEST(SentencePieceTrainerTest, NormalizationTest) {
|
||||
SentencePieceNormalizer sp;
|
||||
EXPECT_OK(sp.Load(model_file));
|
||||
set_normalization_only(&sp);
|
||||
EXPECT_EQ(sp.Normalize("KADOKAWA ABC "), "KADOKAWA ABC ");
|
||||
EXPECT_EQ(sp.Normalize(kInput), "KADOKAWA ABC ");
|
||||
}
|
||||
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user