From 06eee098476885c6cc46a2f23839ed930f1657fe Mon Sep 17 00:00:00 2001 From: Taku Kudo Date: Thu, 4 Jan 2024 09:04:20 +0000 Subject: [PATCH] Added Normalization API --- python/src/sentencepiece/__init__.py | 17 +++ python/src/sentencepiece/sentencepiece.i | 41 ++++++- .../src/sentencepiece/sentencepiece_wrap.cxx | 100 ++++++++++++++++++ python/test/sentencepiece_test.py | 30 ++++++ src/sentencepiece_processor.cc | 20 ++++ src/sentencepiece_processor.h | 15 +++ 6 files changed, 219 insertions(+), 4 deletions(-) diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index 2bfd645..acf1490 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -387,6 +387,12 @@ class SentencePieceProcessor(object): def _SampleEncodeAndScoreAsImmutableProto(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece): return _sentencepiece.SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece) + def _Normalize(self, text): + return _sentencepiece.SentencePieceProcessor__Normalize(self, text) + + def _NormalizeWithOffsets(self, text): + return _sentencepiece.SentencePieceProcessor__NormalizeWithOffsets(self, text) + def _CalculateEntropy(self, text, alpha): return _sentencepiece.SentencePieceProcessor__CalculateEntropy(self, text, alpha) @@ -859,6 +865,17 @@ class SentencePieceProcessor(object): return self._CalculateEntropy(input, alpha) + def Normalize(self, input, with_offsets=None): + def _normalize(text): + if with_offsets: + return self._NormalizeWithOffsets(text) + return self._Normalize(text) + + if type(input) is list: + return [_normalize(x) for x in input] + return _normalize(input) + + def piece_size(self): return self.GetPieceSize() diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index 5b28abc..eca2948 100644 --- 
a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -347,6 +347,9 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) { %ignore sentencepiece::SentencePieceProcessor::DecodePiecesAsImmutableProto; %ignore sentencepiece::SentencePieceProcessor::DecodeIdsAsImmutableProto; +%ignore sentencepiece::SentencePieceProcessor::Normalize; +%ignore sentencepiece::SentencePieceProcessor::NormalizeWithOffsets; + %ignore sentencepiece::SentencePieceProcessor::model_proto; %ignore sentencepiece::SentencePieceProcessor::Load; %ignore sentencepiece::SentencePieceProcessor::LoadOrDie; @@ -648,6 +651,16 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) { return proto; } + // Normalize + std::string _Normalize(absl::string_view text) { + return $self->Normalize(text); + } + + std::pair<std::string, std::vector<size_t>> _NormalizeWithOffsets(absl::string_view text) { + std::pair<std::string, std::vector<size_t>> result; + $self->Normalize(text, &result.first, &result.second).IgnoreError(); + return result; + } // Calculate Entropy float _CalculateEntropy(absl::string_view text, float alpha) { @@ -1020,12 +1033,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) { def SampleEncodeAndScoreAsSerializedProto(self, input, num_samples=None, alpha=None, **kwargs): return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha, out_type='serialized_proto', **kwargs) - + def SampleEncodeAndScoreAsImmutableProto(self, input, num_samples=None, alpha=None, **kwargs): return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha, out_type='immutable_proto', **kwargs) - + def Decode(self, input, out_type=str, num_threads=None): """Decode processed id or token sequences. 
@@ -1140,6 +1153,17 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) { return self._CalculateEntropy(input, alpha) + def Normalize(self, input, with_offsets=None): + def _normalize(text): + if with_offsets: + return self._NormalizeWithOffsets(text) + return self._Normalize(text) + + if type(input) is list: + return [_normalize(x) for x in input] + return _normalize(input) + + def piece_size(self): return self.GetPieceSize() @@ -1315,7 +1339,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) { def __init__(self, proto): self.proto = proto self.len = self.proto._pieces_size() - + def __len__(self): return self.len @@ -1383,7 +1407,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) { @property def nbests(self): return ImmutableNBestSentencePieceText.ImmutableSentencePieceTextIterator(self) - + def __eq__(self, other): return self.SerializeAsString() == other.SerializeAsString() @@ -1654,6 +1678,15 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) { } } +%typemap(out) std::pair<std::string, std::vector<size_t>> { + PyObject *input_type = resultobj; + PyObject *obj = PyList_New($1.second.size()); + for (size_t i = 0; i < $1.second.size(); ++i) { + PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>($1.second[i]))); + } + $result = PyTuple_Pack(2, MakePyOutputString($1.first, input_type), obj); +} + %typemap(in) sentencepiece::SentenceIterator * { sentencepiece::SentenceIterator *out = nullptr; if (PyIter_Check($input)) { diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index 753b2e2..6691137 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -4004,6 +4004,14 @@ SWIGINTERN sentencepiece::ImmutableNBestSentencePieceText sentencepiece_Sentence proto.ConvertToUnicodeSpans(); return proto; } +SWIGINTERN std::string 
sentencepiece_SentencePieceProcessor__Normalize(sentencepiece::SentencePieceProcessor *self,absl::string_view text){ + return self->Normalize(text); + } +SWIGINTERN std::pair< std::string,std::vector< size_t > > sentencepiece_SentencePieceProcessor__NormalizeWithOffsets(sentencepiece::SentencePieceProcessor *self,absl::string_view text){ + std::pair<std::string, std::vector<size_t>> result; + self->Normalize(text, &result.first, &result.second).IgnoreError(); + return result; + } SWIGINTERN float sentencepiece_SentencePieceProcessor__CalculateEntropy(sentencepiece::SentencePieceProcessor *self,absl::string_view text,float alpha){ return self->CalculateEntropy(text, alpha); } @@ -8261,6 +8269,96 @@ fail: } +SWIGINTERN PyObject *_wrap_SentencePieceProcessor__Normalize(PyObject *self, PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + absl::string_view arg2 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject *swig_obj[2] ; + std::string result; + + if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__Normalize", 2, 2, swig_obj)) SWIG_fail; + res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__Normalize" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(swig_obj[1]); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 = ustring.str(); + } + { + try { + result = sentencepiece_SentencePieceProcessor__Normalize(arg1,SWIG_STD_MOVE(arg2)); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + PyObject 
*input_type = resultobj; + resultobj = MakePyOutputString(result, input_type); + } + return resultobj; +fail: + return NULL; +} + + +SWIGINTERN PyObject *_wrap_SentencePieceProcessor__NormalizeWithOffsets(PyObject *self, PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + absl::string_view arg2 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject *swig_obj[2] ; + std::pair< std::string,std::vector< size_t > > result; + + if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__NormalizeWithOffsets", 2, 2, swig_obj)) SWIG_fail; + res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__NormalizeWithOffsets" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + const PyInputString ustring(swig_obj[1]); + if (!ustring.IsAvalable()) { + PyErr_SetString(PyExc_TypeError, "not a string"); + SWIG_fail; + } + resultobj = ustring.input_type(); + arg2 = ustring.str(); + } + { + try { + result = sentencepiece_SentencePieceProcessor__NormalizeWithOffsets(arg1,SWIG_STD_MOVE(arg2)); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + PyObject *input_type = resultobj; + PyObject *obj = PyList_New((&result)->second.size()); + for (size_t i = 0; i < (&result)->second.size(); ++i) { + PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>((&result)->second[i]))); + } + resultobj = PyTuple_Pack(2, MakePyOutputString((&result)->first, input_type), obj); + } + return resultobj; +fail: + return NULL; +} + + SWIGINTERN PyObject *_wrap_SentencePieceProcessor__CalculateEntropy(PyObject *self, PyObject *args) { PyObject 
*resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -8825,6 +8923,8 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor__SampleEncodeAndScoreAsPieces", _wrap_SentencePieceProcessor__SampleEncodeAndScoreAsPieces, METH_VARARGS, NULL}, { "SentencePieceProcessor__SampleEncodeAndScoreAsSerializedProto", _wrap_SentencePieceProcessor__SampleEncodeAndScoreAsSerializedProto, METH_VARARGS, NULL}, { "SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto", _wrap_SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto, METH_VARARGS, NULL}, + { "SentencePieceProcessor__Normalize", _wrap_SentencePieceProcessor__Normalize, METH_VARARGS, NULL}, + { "SentencePieceProcessor__NormalizeWithOffsets", _wrap_SentencePieceProcessor__NormalizeWithOffsets, METH_VARARGS, NULL}, { "SentencePieceProcessor__CalculateEntropy", _wrap_SentencePieceProcessor__CalculateEntropy, METH_VARARGS, NULL}, { "SentencePieceProcessor__CalculateEntropyBatch", _wrap_SentencePieceProcessor__CalculateEntropyBatch, METH_VARARGS, NULL}, { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL}, diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index adbc607..46288b9 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -760,6 +760,36 @@ class TestSentencepieceProcessor(unittest.TestCase): spm.set_random_generator_seed(1) spm.set_min_log_level(3) + def test_normalize(self): + sp = spm.SentencePieceProcessor( + model_file=os.path.join('test', 'test_model.model') + ) + + self.assertEqual('▁KADOKAWAABC', sp.normalize('KADOKAWAABC')) + self.assertEqual('▁KADOKAWAABC', sp.Normalize('KADOKAWAABC')) + + x = sp.Normalize('KADOKAWAABC', with_offsets=True) + self.assertEqual('▁KADOKAWAABC', x[0]) + self.assertEqual( + [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1] + ) + + self.assertEqual( + ['▁KADOKAWAABC', '▁平成'], 
sp.normalize(['KADOKAWAABC', '㍻']) + ) + self.assertEqual( + ['▁KADOKAWAABC', '▁平成'], sp.Normalize(['KADOKAWAABC', '㍻']) + ) + + x = sp.Normalize(['KADOKAWAABC', '㍻'], with_offsets=True) + self.assertEqual(len(x), 2) + self.assertEqual('▁KADOKAWAABC', x[0][0]) + self.assertEqual( + [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1] + ) + self.assertEqual('▁平成', x[1][0]) + self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0, 3], x[1][1]) + def suite(): suite = unittest.TestSuite() diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index e15af96..2545ab4 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -931,6 +931,26 @@ util::Status SentencePieceProcessor::Decode(const std::vector<int> &ids, return value; \ } +util::Status SentencePieceProcessor::Normalize(absl::string_view input, + std::string *normalized) const { + std::vector<size_t> norm_to_orig; + CHECK_OR_RETURN(normalizer_); + return normalizer_->Normalize(input, normalized, &norm_to_orig); +} + +util::Status SentencePieceProcessor::Normalize( + absl::string_view input, std::string *normalized, + std::vector<size_t> *norm_to_orig) const { + CHECK_OR_RETURN(normalizer_); + return normalizer_->Normalize(input, normalized, norm_to_orig); +} + +std::string SentencePieceProcessor::Normalize(absl::string_view input) const { + std::string normalized; + Normalize(input, &normalized).IgnoreError(); + return normalized; +} + int SentencePieceProcessor::GetPieceSize() const { CHECK_STATUS_OR_RETURN_DEFAULT(0); return model_->GetPieceSize(); diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 7a15517..1892caa 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -614,6 +614,21 @@ class SentencePieceProcessor { #undef DEFINE_SPP_SERIALIZED_PROTO_IMPL #undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL + ////////////////////////////////////////////////////////////// + // Normalization methods. + + // Normalize `input`. 
+ virtual util::Status Normalize(absl::string_view input, + std::string *normalized) const; + + // Normalize `input`. Stores the utf8-byte offset from + // the normalized string to the original input. + virtual util::Status Normalize(absl::string_view input, + std::string *normalized, + std::vector<size_t> *norm_to_orig) const; + + virtual std::string Normalize(absl::string_view input) const; + ////////////////////////////////////////////////////////////// // Vocabulary management methods. //