mirror of
https://github.com/google/sentencepiece.git
synced 2024-12-29 11:11:58 +03:00
added functionality to override normalizer spec
This commit is contained in:
parent
0018af1f31
commit
de1747bbd4
@ -1 +1 @@
|
||||
0.2.00
|
||||
0.2.0
|
||||
|
@ -399,6 +399,9 @@ class SentencePieceProcessor(object):
|
||||
def _CalculateEntropyBatch(self, ins, alpha, num_threads):
|
||||
return _sentencepiece.SentencePieceProcessor__CalculateEntropyBatch(self, ins, alpha, num_threads)
|
||||
|
||||
def _OverrideNormalizerSpec(self, args):
|
||||
return _sentencepiece.SentencePieceProcessor__OverrideNormalizerSpec(self, args)
|
||||
|
||||
def Init(self,
|
||||
model_file=None,
|
||||
model_proto=None,
|
||||
@ -875,6 +878,12 @@ class SentencePieceProcessor(object):
|
||||
return [_normalize(x) for x in input]
|
||||
return _normalize(input)
|
||||
|
||||
def OverrideNormalizerSpec(self, **kwargs):
|
||||
new_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
new_kwargs[key] = str(value)
|
||||
return self._OverrideNormalizerSpec(new_kwargs)
|
||||
|
||||
|
||||
def piece_size(self):
|
||||
return self.GetPieceSize()
|
||||
|
@ -1 +1 @@
|
||||
__version__ = '0.2.00'
|
||||
__version__ = '0.2.0'
|
||||
|
@ -351,6 +351,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
|
||||
%ignore sentencepiece::SentencePieceProcessor::NormalizeWithOffsets;
|
||||
|
||||
%ignore sentencepiece::SentencePieceProcessor::model_proto;
|
||||
%ignore sentencepiece::SentencePieceProcessor::mutable_normalizer_spec;
|
||||
%ignore sentencepiece::SentencePieceProcessor::Load;
|
||||
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
|
||||
%ignore sentencepiece::SentencePieceProcessor::SetModel;
|
||||
@ -690,6 +691,19 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
|
||||
return outs;
|
||||
}
|
||||
|
||||
// override normalizer_spec
|
||||
sentencepiece::util::Status _OverrideNormalizerSpec(
|
||||
const std::unordered_map<std::string, std::string> &args) {
|
||||
sentencepiece::util::Status status;
|
||||
for (const auto &[key, value] : args) {
|
||||
status = sentencepiece::SentencePieceTrainer::SetProtoField(
|
||||
key, value,
|
||||
$self->mutable_normalizer_spec());
|
||||
if (!status.ok()) return status;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
%pythoncode {
|
||||
def Init(self,
|
||||
model_file=None,
|
||||
@ -1167,6 +1181,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
|
||||
return [_normalize(x) for x in input]
|
||||
return _normalize(input)
|
||||
|
||||
def OverrideNormalizerSpec(self, **kwargs):
|
||||
new_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
new_kwargs[key] = str(value)
|
||||
return self._OverrideNormalizerSpec(new_kwargs)
|
||||
|
||||
|
||||
def piece_size(self):
|
||||
return self.GetPieceSize()
|
||||
|
@ -4033,6 +4033,16 @@ SWIGINTERN std::vector< float > sentencepiece_SentencePieceProcessor__CalculateE
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec(sentencepiece::SentencePieceProcessor *self,std::unordered_map< std::string,std::string > const &args){
|
||||
sentencepiece::util::Status status;
|
||||
for (const auto &[key, value] : args) {
|
||||
status = sentencepiece::SentencePieceTrainer::SetProtoField(
|
||||
key, value,
|
||||
self->mutable_normalizer_spec());
|
||||
if (!status.ok()) return status;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
SWIGINTERN int
|
||||
SWIG_AsVal_unsigned_SS_long (PyObject *obj, unsigned long *val)
|
||||
@ -8508,6 +8518,72 @@ fail:
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceProcessor__OverrideNormalizerSpec(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
|
||||
std::unordered_map< std::string,std::string > *arg2 = 0 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
sentencepiece::util::Status result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__OverrideNormalizerSpec", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__OverrideNormalizerSpec" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
|
||||
{
|
||||
std::unordered_map<std::string, std::string> *out = nullptr;
|
||||
if (PyDict_Check(swig_obj[1])) {
|
||||
PyObject *key, *value;
|
||||
Py_ssize_t pos = 0;
|
||||
out = new std::unordered_map<std::string, std::string>;
|
||||
while (PyDict_Next(swig_obj[1], &pos, &key, &value)) {
|
||||
const PyInputString key_ustring(key);
|
||||
const PyInputString value_ustring(value);
|
||||
if (key_ustring.IsAvalable() && value_ustring.IsAvalable()) {
|
||||
out->emplace(std::string(key_ustring.data(), key_ustring.size()),
|
||||
std::string(value_ustring.data(), value_ustring.size()));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, "map must contain strings.");
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = key_ustring.input_type();
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError, "not a dictionary");
|
||||
SWIG_fail;
|
||||
}
|
||||
arg2 = out;
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = sentencepiece_SentencePieceProcessor__OverrideNormalizerSpec(arg1,(std::unordered_map< std::string,std::string > const &)*arg2);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
if (!(&result)->ok()) {
|
||||
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
|
||||
}
|
||||
resultobj = SWIG_From_bool((&result)->ok());
|
||||
}
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *SentencePieceProcessor_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *obj;
|
||||
if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL;
|
||||
@ -9362,6 +9438,7 @@ static PyMethodDef SwigMethods[] = {
|
||||
{ "SentencePieceProcessor__NormalizeWithOffsets", _wrap_SentencePieceProcessor__NormalizeWithOffsets, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor__CalculateEntropy", _wrap_SentencePieceProcessor__CalculateEntropy, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor__CalculateEntropyBatch", _wrap_SentencePieceProcessor__CalculateEntropyBatch, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor__OverrideNormalizerSpec", _wrap_SentencePieceProcessor__OverrideNormalizerSpec, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},
|
||||
{ "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL},
|
||||
{ "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL},
|
||||
|
@ -848,6 +848,23 @@ class TestSentencepieceProcessor(unittest.TestCase):
|
||||
sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
|
||||
self.assertEqual('abc', sp.Normalize('ABC'))
|
||||
|
||||
def test_override_normalize_spec(self):
|
||||
sp = spm.SentencePieceProcessor(
|
||||
model_file=os.path.join('test', 'test_model.model')
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
sp.EncodeAsPieces(' hello world '), ['▁he', 'll', 'o', '▁world']
|
||||
)
|
||||
|
||||
sp.override_normalizer_spec(add_dummy_prefix=False)
|
||||
sp.override_normalizer_spec(remove_extra_whitespaces=False)
|
||||
sp.override_normalizer_spec(escape_whitespaces=False)
|
||||
self.assertEqual(
|
||||
sp.EncodeAsPieces(' hello world '),
|
||||
[' ', 'he', 'll', 'o', ' ', 'w', 'or', 'l', 'd', ' '],
|
||||
)
|
||||
|
||||
|
||||
def suite():
|
||||
suite = unittest.TestSuite()
|
||||
|
@ -1117,6 +1117,10 @@ std::string SentencePieceProcessor::serialized_model_proto() const {
|
||||
return model_proto_ ? model_proto_->SerializeAsString() : "";
|
||||
}
|
||||
|
||||
NormalizerSpec *SentencePieceProcessor::mutable_normalizer_spec() const {
|
||||
return model_proto_ ? model_proto_->mutable_normalizer_spec() : nullptr;
|
||||
}
|
||||
|
||||
// Set seed value of random generator.
|
||||
// Do not set static_cast<unique_int>(-1),
|
||||
// as this seed is reserved for initializing from
|
||||
|
@ -134,6 +134,7 @@ class NBestSentencePieceText;
|
||||
class ModelInterface;
|
||||
class SentencePieceText;
|
||||
class ModelProto;
|
||||
class NormalizerSpec;
|
||||
|
||||
namespace normalizer {
|
||||
class Normalizer;
|
||||
@ -692,6 +693,11 @@ class SentencePieceProcessor {
|
||||
// Useful to save the state of this instance via Python's pickle object.
|
||||
util::bytes serialized_model_proto() const;
|
||||
|
||||
// Returns mutable normalizer_spec.
|
||||
// Updating the intenral normalization during the encoding/decoding are not
|
||||
// recommended and may result in unexpected behavior. Use at your own risk.
|
||||
NormalizerSpec *mutable_normalizer_spec() const;
|
||||
|
||||
private:
|
||||
enum ExtraOption { REVERSE, BOS, EOS, UNK_PIECE };
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user