diff --git a/VERSION.txt b/VERSION.txt index faeda78..0ea3a94 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.2.00 +0.2.0 diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index 4a8c96d..5e24d54 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -399,6 +399,9 @@ class SentencePieceProcessor(object): def _CalculateEntropyBatch(self, ins, alpha, num_threads): return _sentencepiece.SentencePieceProcessor__CalculateEntropyBatch(self, ins, alpha, num_threads) + def _OverrideNormalizerSpec(self, args): + return _sentencepiece.SentencePieceProcessor__OverrideNormalizerSpec(self, args) + def Init(self, model_file=None, model_proto=None, @@ -875,6 +878,12 @@ class SentencePieceProcessor(object): return [_normalize(x) for x in input] return _normalize(input) + def OverrideNormalizerSpec(self, **kwargs): + new_kwargs = {} + for key, value in kwargs.items(): + new_kwargs[key] = str(value) + return self._OverrideNormalizerSpec(new_kwargs) + def piece_size(self): return self.GetPieceSize() diff --git a/python/src/sentencepiece/_version.py b/python/src/sentencepiece/_version.py index 83130b8..7fd229a 100644 --- a/python/src/sentencepiece/_version.py +++ b/python/src/sentencepiece/_version.py @@ -1 +1 @@ -__version__ = '0.2.00' +__version__ = '0.2.0' diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index 4dcdc18..4323f76 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -351,6 +351,7 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { %ignore sentencepiece::SentencePieceProcessor::NormalizeWithOffsets; %ignore sentencepiece::SentencePieceProcessor::model_proto; +%ignore sentencepiece::SentencePieceProcessor::mutable_normalizer_spec; %ignore sentencepiece::SentencePieceProcessor::Load; %ignore sentencepiece::SentencePieceProcessor::LoadOrDie; %ignore sentencepiece::SentencePieceProcessor::SetModel; @@ -690,6 +691,19 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { return outs; } + // override normalizer_spec + sentencepiece::util::Status _OverrideNormalizerSpec( + const std::unordered_map &args) { + sentencepiece::util::Status status; + for (const auto &[key, value] : args) { + status = sentencepiece::SentencePieceTrainer::SetProtoField( + key, value, + $self->mutable_normalizer_spec()); + if (!status.ok()) return status; + } + return status; + } + %pythoncode { def Init(self, model_file=None, @@ -1167,6 +1181,12 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { return [_normalize(x) for x in input] return _normalize(input) + def OverrideNormalizerSpec(self, **kwargs): + new_kwargs = {} + for key, value in kwargs.items(): + new_kwargs[key] = str(value) + return self._OverrideNormalizerSpec(new_kwargs) + def piece_size(self): return self.GetPieceSize() diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index 720c93a..b08a543 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -4033,6 +4033,16 @@ SWIGINTERN std::vector< float > sentencepiece_SentencePieceProcessor__CalculateE } return outs; } +SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec(sentencepiece::SentencePieceProcessor *self,std::unordered_map< std::string,std::string > const &args){ + sentencepiece::util::Status status; + for (const auto &[key, value] : args) { + status = sentencepiece::SentencePieceTrainer::SetProtoField( + key, value, + self->mutable_normalizer_spec()); + if (!status.ok()) return status; + } + return status; + } SWIGINTERN int SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val) @@ -8508,6 +8518,72 @@ fail: } +SWIGINTERN PyObject *_wrap_SentencePieceProcessor__OverrideNormalizerSpec(PyObject *self, PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + std::unordered_map< std::string,std::string > *arg2 = 0 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject *swig_obj[2] ; + sentencepiece::util::Status result; + + if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__OverrideNormalizerSpec", 2, 2, swig_obj)) SWIG_fail; + res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__OverrideNormalizerSpec" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + std::unordered_map *out = nullptr; + if (PyDict_Check(swig_obj[1])) { + PyObject *key, *value; + Py_ssize_t pos = 0; + out = new std::unordered_map; + while (PyDict_Next(swig_obj[1], &pos, &key, &value)) { + const PyInputString key_ustring(key); + const PyInputString value_ustring(value); + if (key_ustring.IsAvalable() && value_ustring.IsAvalable()) { + out->emplace(std::string(key_ustring.data(), key_ustring.size()), + std::string(value_ustring.data(), value_ustring.size())); + } else { + PyErr_SetString(PyExc_TypeError, "map must contain strings."); + SWIG_fail; + } + resultobj = key_ustring.input_type(); + } + } else { + PyErr_SetString(PyExc_TypeError, "not a dictionary"); + SWIG_fail; + } + arg2 = out; + } + { + try { + result = sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec(arg1,(std::unordered_map< std::string,std::string > const &)*arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + if (!(&result)->ok()) { + SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str()); + } + resultobj = SWIG_From_bool((&result)->ok()); + } + { + delete arg2; + } + return resultobj; +fail: + { + delete arg2; + } + return NULL; +} + + SWIGINTERN PyObject *SentencePieceProcessor_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *obj; if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL; @@ -9362,6 +9438,7 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor__NormalizeWithOffsets", _wrap_SentencePieceProcessor__NormalizeWithOffsets, METH_VARARGS, NULL}, { "SentencePieceProcessor__CalculateEntropy", _wrap_SentencePieceProcessor__CalculateEntropy, METH_VARARGS, NULL}, { "SentencePieceProcessor__CalculateEntropyBatch", _wrap_SentencePieceProcessor__CalculateEntropyBatch, METH_VARARGS, NULL}, + { "SentencePieceProcessor__OverrideNormalizerSpec", _wrap_SentencePieceProcessor__OverrideNormalizerSpec, METH_VARARGS, NULL}, { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL}, { "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL}, { "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL}, diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 67e272b..9dd91a7 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -848,6 +848,23 @@ class TestSentencepieceProcessor(unittest.TestCase): sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf') self.assertEqual('abc', sp.Normalize('ABC')) + def test_override_normalize_spec(self): + sp = spm.SentencePieceProcessor( + model_file=os.path.join('test', 'test_model.model') + ) + + self.assertEqual( + sp.EncodeAsPieces(' hello world '), ['▁he', 'll', 'o', '▁world'] + ) + + sp.override_normalizer_spec(add_dummy_prefix=False) + sp.override_normalizer_spec(remove_extra_whitespaces=False) + sp.override_normalizer_spec(escape_whitespaces=False) + self.assertEqual( + sp.EncodeAsPieces(' hello world '), + [' ', 'he', 'll', 'o', ' ', 'w', 'or', 'l', 'd', ' '], + ) + def suite(): suite = unittest.TestSuite() diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 7b1951f..5d2c857 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -1117,6 +1117,10 @@ std::string SentencePieceProcessor::serialized_model_proto() const { return model_proto_ ? model_proto_->SerializeAsString() : ""; } +NormalizerSpec *SentencePieceProcessor::mutable_normalizer_spec() const { + return model_proto_ ? model_proto_->mutable_normalizer_spec() : nullptr; +} + // Set seed value of random generator. // Do not set static_cast(-1), // as this seed is reserved for initializing from diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 1892caa..dd3f092 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -134,6 +134,7 @@ class NBestSentencePieceText; class ModelInterface; class SentencePieceText; class ModelProto; +class NormalizerSpec; namespace normalizer { class Normalizer; @@ -692,6 +693,11 @@ class SentencePieceProcessor { // Useful to save the state of this instance via Python's pickle object. util::bytes serialized_model_proto() const; + // Returns mutable normalizer_spec. + // Updating the intenral normalization during the encoding/decoding are not + // recommended and may result in unexpected behavior. Use at your own risk. + NormalizerSpec *mutable_normalizer_spec() const; + private: enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };