Added functionality to override the normalizer spec.

This commit is contained in:
Taku Kudo 2024-01-16 04:06:05 +00:00
parent 0018af1f31
commit de1747bbd4
8 changed files with 135 additions and 2 deletions

View File

@ -1 +1 @@
0.2.00
0.2.0

View File

@ -399,6 +399,9 @@ class SentencePieceProcessor(object):
def _CalculateEntropyBatch(self, ins, alpha, num_threads):
return _sentencepiece.SentencePieceProcessor__CalculateEntropyBatch(self, ins, alpha, num_threads)
def _OverrideNormalizerSpec(self, args):
    """Low-level bridge into the C++ extension.

    `args` is a dict mapping NormalizerSpec field names to string values;
    the heavy lifting happens in the native module.
    """
    native_result = _sentencepiece.SentencePieceProcessor__OverrideNormalizerSpec(self, args)
    return native_result
def Init(self,
model_file=None,
model_proto=None,
@ -875,6 +878,12 @@ class SentencePieceProcessor(object):
return [_normalize(x) for x in input]
return _normalize(input)
def OverrideNormalizerSpec(self, **kwargs):
    """Override fields of the loaded model's normalizer spec.

    Each keyword argument names a NormalizerSpec field (e.g.
    ``add_dummy_prefix``); values are stringified before being handed to
    the native layer, which parses them into the proto field types.
    """
    return self._OverrideNormalizerSpec(
        {key: str(value) for key, value in kwargs.items()}
    )
def piece_size(self):
return self.GetPieceSize()

View File

@ -1 +1 @@
__version__ = '0.2.00'
__version__ = '0.2.0'

View File

@ -351,6 +351,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
%ignore sentencepiece::SentencePieceProcessor::NormalizeWithOffsets;
%ignore sentencepiece::SentencePieceProcessor::model_proto;
%ignore sentencepiece::SentencePieceProcessor::mutable_normalizer_spec;
%ignore sentencepiece::SentencePieceProcessor::Load;
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
%ignore sentencepiece::SentencePieceProcessor::SetModel;
@ -690,6 +691,19 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
return outs;
}
// Override normalizer_spec.
// Applies each (field, value) pair in `args` to the processor's mutable
// NormalizerSpec via SentencePieceTrainer::SetProtoField, stopping at the
// first failure.  Returns OK for an empty map.
sentencepiece::util::Status _OverrideNormalizerSpec(
    const std::unordered_map<std::string, std::string> &args) {
  sentencepiece::util::Status last_status;
  for (const auto &field : args) {
    last_status = sentencepiece::SentencePieceTrainer::SetProtoField(
        field.first, field.second, $self->mutable_normalizer_spec());
    if (!last_status.ok()) break;
  }
  return last_status;
}
%pythoncode {
def Init(self,
model_file=None,
@ -1167,6 +1181,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
return [_normalize(x) for x in input]
return _normalize(input)
def OverrideNormalizerSpec(self, **kwargs):
    """Override normalizer-spec fields on the loaded model.

    Keyword names are NormalizerSpec field names; every value is converted
    to ``str`` so the C++ side can parse it into the proto field.
    """
    stringified = dict((key, str(value)) for key, value in kwargs.items())
    return self._OverrideNormalizerSpec(stringified)
def piece_size(self):
return self.GetPieceSize()

View File

@ -4033,6 +4033,16 @@ SWIGINTERN std::vector< float > sentencepiece_SentencePieceProcessor__CalculateE
}
return outs;
}
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec(sentencepiece::SentencePieceProcessor *self,std::unordered_map< std::string,std::string > const &args){
    // Apply every (field, value) override to the processor's NormalizerSpec,
    // bailing out on the first SetProtoField failure.  An empty map yields
    // a default (OK) status.
    sentencepiece::util::Status last;
    for (auto it = args.begin(); it != args.end(); ++it) {
      last = sentencepiece::SentencePieceTrainer::SetProtoField(
          it->first, it->second, self->mutable_normalizer_spec());
      if (!last.ok()) return last;
    }
    return last;
  }
SWIGINTERN int
SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val)
@ -8508,6 +8518,72 @@ fail:
}
// Python wrapper: SentencePieceProcessor._OverrideNormalizerSpec(dict).
// Converts the dict argument to std::unordered_map<string, string>, invokes
// the C++ extension method, and maps the returned Status to a Python bool
// (raising on a non-OK status).
SWIGINTERN PyObject *_wrap_SentencePieceProcessor__OverrideNormalizerSpec(PyObject *self, PyObject *args) {
  PyObject *resultobj = 0;
  sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
  std::unordered_map< std::string,std::string > *arg2 = 0 ;
  void *argp1 = 0 ;
  int res1 = 0 ;
  PyObject *swig_obj[2] ;
  sentencepiece::util::Status result;

  if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__OverrideNormalizerSpec", 2, 2, swig_obj)) SWIG_fail;
  res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
  if (!SWIG_IsOK(res1)) {
    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__OverrideNormalizerSpec" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
  }
  arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
  {
    // Convert the Python dict to an unordered_map<string, string>.
    if (PyDict_Check(swig_obj[1])) {
      PyObject *key, *value;
      Py_ssize_t pos = 0;
      // BUGFIX: assign to arg2 immediately so the `delete arg2;` at the
      // `fail:` label frees the map when a conversion error below jumps
      // there.  The previous code kept the map in a local (`out`) and only
      // assigned arg2 after the loop, leaking the allocation on failure.
      arg2 = new std::unordered_map<std::string, std::string>;
      while (PyDict_Next(swig_obj[1], &pos, &key, &value)) {
        const PyInputString key_ustring(key);
        const PyInputString value_ustring(value);
        if (key_ustring.IsAvalable() && value_ustring.IsAvalable()) {
          arg2->emplace(std::string(key_ustring.data(), key_ustring.size()),
                        std::string(value_ustring.data(), value_ustring.size()));
        } else {
          PyErr_SetString(PyExc_TypeError, "map must contain strings.");
          SWIG_fail;
        }
        // Stash the key's input type; consumed by ReleaseResultObject below.
        resultobj = key_ustring.input_type();
      }
    } else {
      PyErr_SetString(PyExc_TypeError, "not a dictionary");
      SWIG_fail;
    }
  }
  {
    try {
      result = sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec(arg1,(std::unordered_map< std::string,std::string > const &)*arg2);
      ReleaseResultObject(resultobj);
    }
    catch (const sentencepiece::util::Status &status) {
      SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
    }
  }
  {
    // Raise when the returned Status is not OK; otherwise hand back True.
    if (!(&result)->ok()) {
      SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
    }
    resultobj = SWIG_From_bool((&result)->ok());
  }
  {
    delete arg2;
  }
  return resultobj;
fail:
  {
    delete arg2;
  }
  return NULL;
}
SWIGINTERN PyObject *SentencePieceProcessor_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *obj;
if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL;
@ -9362,6 +9438,7 @@ static PyMethodDef SwigMethods[] = {
{ "SentencePieceProcessor__NormalizeWithOffsets", _wrap_SentencePieceProcessor__NormalizeWithOffsets, METH_VARARGS, NULL},
{ "SentencePieceProcessor__CalculateEntropy", _wrap_SentencePieceProcessor__CalculateEntropy, METH_VARARGS, NULL},
{ "SentencePieceProcessor__CalculateEntropyBatch", _wrap_SentencePieceProcessor__CalculateEntropyBatch, METH_VARARGS, NULL},
{ "SentencePieceProcessor__OverrideNormalizerSpec", _wrap_SentencePieceProcessor__OverrideNormalizerSpec, METH_VARARGS, NULL},
{ "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},
{ "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL},
{ "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL},

View File

@ -848,6 +848,23 @@ class TestSentencepieceProcessor(unittest.TestCase):
sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
self.assertEqual('abc', sp.Normalize(''))
def test_override_normalize_spec(self):
    """Overriding whitespace-related normalizer flags changes tokenization."""
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )
    # Default spec: dummy prefix added, extra whitespace removed/escaped.
    self.assertEqual(
        sp.EncodeAsPieces(' hello world '), ['▁he', 'll', 'o', '▁world']
    )
    # Turn off each whitespace normalization, one field per call.
    for field in ('add_dummy_prefix', 'remove_extra_whitespaces',
                  'escape_whitespaces'):
        sp.override_normalizer_spec(**{field: False})
    self.assertEqual(
        sp.EncodeAsPieces(' hello world '),
        [' ', 'he', 'll', 'o', ' ', 'w', 'or', 'l', 'd', ' '],
    )
def suite():
suite = unittest.TestSuite()

View File

@ -1117,6 +1117,10 @@ std::string SentencePieceProcessor::serialized_model_proto() const {
return model_proto_ ? model_proto_->SerializeAsString() : "";
}
// Returns a mutable pointer to the loaded model's NormalizerSpec, or
// nullptr when no model proto has been loaded.
NormalizerSpec *SentencePieceProcessor::mutable_normalizer_spec() const {
  if (model_proto_ == nullptr) return nullptr;
  return model_proto_->mutable_normalizer_spec();
}
// Set seed value of random generator.
// Do not set static_cast<unique_int>(-1),
// as this seed is reserved for initializing from

View File

@ -134,6 +134,7 @@ class NBestSentencePieceText;
class ModelInterface;
class SentencePieceText;
class ModelProto;
class NormalizerSpec;
namespace normalizer {
class Normalizer;
@ -692,6 +693,11 @@ class SentencePieceProcessor {
// Useful to save the state of this instance via Python's pickle object.
util::bytes serialized_model_proto() const;
// Returns mutable normalizer_spec.
// Updating the internal normalization during encoding/decoding is not
// recommended and may result in unexpected behavior. Use at your own risk.
NormalizerSpec *mutable_normalizer_spec() const;
private:
enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };