add more advanced SentencePieceNormalizer class

This commit is contained in:
Taku Kudo 2024-01-13 17:19:50 +00:00
parent f5c736302c
commit ed76ecc478
7 changed files with 906 additions and 10 deletions

View File

@ -1004,6 +1004,98 @@ class SentencePieceTrainer(object):
# Register SentencePieceTrainer in _sentencepiece:
_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer)
class SentencePieceNormalizer(object):
    """Proxy of the C++ sentencepiece::SentencePieceNormalizer class.

    A standalone text normalizer that can be initialized from a sentencepiece
    model (file path or serialized proto) or from normalization rules (a TSV
    file or a pre-defined rule name).  See Init() for the supported options.
    """

    thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
    __repr__ = _swig_repr

    def __init__(self):
        # Native constructor; saved as _sentencepiece_normalizer_init_native
        # at module bottom and replaced there by Init().
        _sentencepiece.SentencePieceNormalizer_swiginit(self, _sentencepiece.new_SentencePieceNormalizer())
    __swig_destroy__ = _sentencepiece.delete_SentencePieceNormalizer

    def LoadFromSerializedProto(self, serialized):
        # Loads the normalizer from a serialized ModelProto byte string.
        return _sentencepiece.SentencePieceNormalizer_LoadFromSerializedProto(self, serialized)

    def LoadFromRuleTSV(self, filename):
        # Loads normalization rules from a TSV file.
        return _sentencepiece.SentencePieceNormalizer_LoadFromRuleTSV(self, filename)

    def LoadFromRuleName(self, name):
        # Loads a pre-defined normalization rule by name.
        return _sentencepiece.SentencePieceNormalizer_LoadFromRuleName(self, name)

    def serialized_model_proto(self):
        # Serialized ModelProto of the loaded normalizer; used by __getstate__.
        return _sentencepiece.SentencePieceNormalizer_serialized_model_proto(self)

    def LoadFromFile(self, arg):
        # Loads the normalizer from a model file path.
        return _sentencepiece.SentencePieceNormalizer_LoadFromFile(self, arg)

    def _Normalize(self, text):
        # Normalizes one string; returns the normalized text only.
        return _sentencepiece.SentencePieceNormalizer__Normalize(self, text)

    def _NormalizeWithOffsets(self, text):
        # Normalizes one string; returns (normalized_text, offsets) where
        # offsets maps normalized positions back to the original input.
        return _sentencepiece.SentencePieceNormalizer__NormalizeWithOffsets(self, text)

    def _SetProtoField(self, name, value):
        # Sets a boolean field of the underlying NormalizerSpec proto.
        return _sentencepiece.SentencePieceNormalizer__SetProtoField(self, name, value)

    def Init(self,
             model_file=None,
             model_proto=None,
             rule_tsv=None,
             rule_name=None,
             add_dummy_prefix=False,
             escape_whitespaces=False,
             remove_extra_whitespaces=False):
        """Initialize SentencePieceNormalizer.

        Args:
          model_file: The sentencepiece model file path.
          model_proto: The sentencepiece model serialized proto.
          rule_tsv: The normalization rule file in TSV format.
          rule_name: Pre-defined normalization name.
          add_dummy_prefix: add dummy prefix.
          escape_whitespaces: escape whitespaces.
          remove_extra_whitespaces: remove extra whitespaces.
        """
        _sentencepiece_normalizer_init_native(self)
        # Exactly one model source must be given.  The wrapped Load* methods
        # raise on a non-ok status, so `status` is truthy on success.
        if model_file:
            status = self.LoadFromFile(model_file)
        elif model_proto:
            status = self.LoadFromSerializedProto(model_proto)
        elif rule_tsv:
            status = self.LoadFromRuleTSV(rule_tsv)
        elif rule_name:
            status = self.LoadFromRuleName(rule_name)
        else:
            raise RuntimeError('no model is specified')
        if status:
            # Override the whitespace-handling flags of the loaded spec.
            self._SetProtoField('add_dummy_prefix', add_dummy_prefix)
            self._SetProtoField('escape_whitespaces', escape_whitespaces)
            self._SetProtoField('remove_extra_whitespaces', remove_extra_whitespaces)

    def Normalize(self, input, with_offsets=None):
        # Normalizes a string or a list of strings.  With with_offsets=True,
        # each result is a (normalized_text, offsets) pair.
        def _normalize(text):
            if with_offsets:
                return self._NormalizeWithOffsets(text)
            return self._Normalize(text)
        if type(input) is list:
            return [_normalize(x) for x in input]
        return _normalize(input)

    def __getstate__(self):
        # Pickle support: the whole state is the serialized model proto.
        return self.serialized_model_proto()

    def __setstate__(self, serialized_model_proto):
        # Bug fix: self.__init__ is rebound to Init() at module bottom, and
        # Init() with no arguments raises RuntimeError('no model is
        # specified'), which broke unpickling.  Call the saved native
        # constructor directly, then load the pickled proto.
        _sentencepiece_normalizer_init_native(self)
        self.LoadFromSerializedProto(serialized_model_proto)
# Register SentencePieceNormalizer in _sentencepiece:
_sentencepiece.SentencePieceNormalizer_swigregister(SentencePieceNormalizer)
import re
@ -1045,7 +1137,9 @@ def _batchnize(classname, name):
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
_sentencepiece_normalizer_init_native = SentencePieceNormalizer.__init__
setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
setattr(SentencePieceNormalizer, '__init__', SentencePieceNormalizer.Init)
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
@ -1058,6 +1152,7 @@ for m in [
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
_add_snake_case(SentencePieceNormalizer)
set_random_generator_seed = SetRandomGeneratorSeed
set_min_log_level = SetMinLogLevel

View File

@ -368,6 +368,10 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
%ignore sentencepiece::SentencePieceTrainer::SetPretokenizerForTraining;
%ignore sentencepiece::SentencePieceTrainer::GetPretokenizerForTraining;
%ignore sentencepiece::SentencePieceNormalizer::Load;
%ignore sentencepiece::SentencePieceNormalizer::Normalize;
%ignore sentencepiece::SentencePieceNormalizer::mutable_normalizer_spec;
%ignore sentencepiece::io::LoadModelProto;
%ignore sentencepiece::io::SaveModelProto;
@ -1293,6 +1297,92 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
}
}
%extend sentencepiece::SentencePieceNormalizer {
  // Exposes Load() under a distinct name (the C++ Load overload set is
  // %ignore'd above).
  sentencepiece::util::Status LoadFromFile(absl::string_view arg) {
    return $self->Load(arg);
  }

  // Normalizes `text`, returning only the normalized string.  A thrown
  // util::Status is converted to a Python exception by the %exception glue.
  std::string _Normalize(absl::string_view text) {
    std::string result;
    const auto _status = $self->Normalize(text, &result);
    if (!_status.ok()) throw _status;
    return result;
  }

  // Normalizes `text` and also returns the mapping from normalized
  // positions back to positions in the original input.
  std::pair<std::string, std::vector<size_t>> _NormalizeWithOffsets(absl::string_view text) {
    std::pair<std::string, std::vector<size_t>> result;
    const auto _status = $self->Normalize(text, &result.first, &result.second);
    if (!_status.ok()) throw _status;
    return result;
  }

  // Sets a boolean field (e.g. 'add_dummy_prefix') on the NormalizerSpec.
  // Errors from SetProtoField are deliberately ignored.
  void _SetProtoField(absl::string_view name, bool value) {
    sentencepiece::SentencePieceTrainer::SetProtoField(
        name,
        value ? "1" : "0",
        $self->mutable_normalizer_spec()).IgnoreError();
  }

%pythoncode %{
  def Init(self,
           model_file=None,
           model_proto=None,
           rule_tsv=None,
           rule_name=None,
           add_dummy_prefix=False,
           escape_whitespaces=False,
           remove_extra_whitespaces=False):
    """Initialize SentencePieceNormalizer.

    Args:
      model_file: The sentencepiece model file path.
      model_proto: The sentencepiece model serialized proto.
      rule_tsv: The normalization rule file in TSV format.
      rule_name: Pre-defined normalization name.
      add_dummy_prefix: add dummy prefix.
      escape_whitespaces: escape whitespaces.
      remove_extra_whitespaces: remove extra whitespaces.
    """
    _sentencepiece_normalizer_init_native(self)
    # Exactly one model source must be given; the wrapped Load* methods
    # raise on a non-ok status, so `status` is truthy on success.
    if model_file:
      status = self.LoadFromFile(model_file)
    elif model_proto:
      status = self.LoadFromSerializedProto(model_proto)
    elif rule_tsv:
      status = self.LoadFromRuleTSV(rule_tsv)
    elif rule_name:
      status = self.LoadFromRuleName(rule_name)
    else:
      raise RuntimeError('no model is specified')
    if status:
      # Override the whitespace-handling flags of the loaded spec.
      self._SetProtoField('add_dummy_prefix', add_dummy_prefix)
      self._SetProtoField('escape_whitespaces', escape_whitespaces)
      self._SetProtoField('remove_extra_whitespaces', remove_extra_whitespaces)

  def Normalize(self, input, with_offsets=None):
    # Normalizes a string or a list of strings; with_offsets=True yields
    # (normalized_text, offsets) pairs.
    def _normalize(text):
      if with_offsets:
        return self._NormalizeWithOffsets(text)
      return self._Normalize(text)
    if type(input) is list:
      return [_normalize(x) for x in input]
    return _normalize(input)

  def __getstate__(self):
    return self.serialized_model_proto()

  def __setstate__(self, serialized_model_proto):
    # Bug fix: self.__init__ is rebound to Init(), which raises when called
    # with no model; use the saved native constructor for unpickling.
    _sentencepiece_normalizer_init_native(self)
    self.LoadFromSerializedProto(serialized_model_proto)
%}
}
%extend sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece {
%rename(_piece) piece;
%rename(_id) id;
@ -1790,7 +1880,9 @@ def _batchnize(classname, name):
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
_sentencepiece_normalizer_init_native = SentencePieceNormalizer.__init__
setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
setattr(SentencePieceNormalizer, '__init__', SentencePieceNormalizer.Init)
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
@ -1803,6 +1895,7 @@ for m in [
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
_add_snake_case(SentencePieceNormalizer)
set_random_generator_seed = SetRandomGeneratorSeed
set_min_log_level = SetMinLogLevel

View File

@ -2987,16 +2987,17 @@ SWIG_Python_NonDynamicSetAttr(PyObject *obj, PyObject *name, PyObject *value) {
#define SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText swig_types[3]
#define SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece swig_types[4]
#define SWIGTYPE_p_sentencepiece__SentenceIterator swig_types[5]
#define SWIGTYPE_p_sentencepiece__SentencePieceProcessor swig_types[6]
#define SWIGTYPE_p_sentencepiece__SentencePieceTrainer swig_types[7]
#define SWIGTYPE_p_std__string swig_types[8]
#define SWIGTYPE_p_std__unordered_mapT_std__string_std__string_t swig_types[9]
#define SWIGTYPE_p_std__vectorT_absl__string_view_t swig_types[10]
#define SWIGTYPE_p_std__vectorT_int_t swig_types[11]
#define SWIGTYPE_p_std__vectorT_std__vectorT_absl__string_view_t_t swig_types[12]
#define SWIGTYPE_p_std__vectorT_std__vectorT_int_t_t swig_types[13]
static swig_type_info *swig_types[15];
static swig_module_info swig_module = {swig_types, 14, 0, 0, 0, 0};
#define SWIGTYPE_p_sentencepiece__SentencePieceNormalizer swig_types[6]
#define SWIGTYPE_p_sentencepiece__SentencePieceProcessor swig_types[7]
#define SWIGTYPE_p_sentencepiece__SentencePieceTrainer swig_types[8]
#define SWIGTYPE_p_std__string swig_types[9]
#define SWIGTYPE_p_std__unordered_mapT_std__string_std__string_t swig_types[10]
#define SWIGTYPE_p_std__vectorT_absl__string_view_t swig_types[11]
#define SWIGTYPE_p_std__vectorT_int_t swig_types[12]
#define SWIGTYPE_p_std__vectorT_std__vectorT_absl__string_view_t_t swig_types[13]
#define SWIGTYPE_p_std__vectorT_std__vectorT_int_t_t swig_types[14]
static swig_type_info *swig_types[16];
static swig_module_info swig_module = {swig_types, 15, 0, 0, 0, 0};
#define SWIG_TypeQuery(name) SWIG_TypeQueryModule(&swig_module, &swig_module, name)
#define SWIG_MangledTypeQuery(name) SWIG_MangledTypeQueryModule(&swig_module, &swig_module, name)
@ -4123,6 +4124,27 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainF
if (!_status.ok()) throw _status;
return model_proto;
}
/* SWIG-generated helper: LoadFromFile is a rename of the C++
   Load(absl::string_view) overload. */
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceNormalizer_LoadFromFile(sentencepiece::SentencePieceNormalizer *self,absl::string_view arg){
return self->Load(arg);
}
/* SWIG-generated helper: normalizes `text`; throws util::Status on failure
   (the wrapper converts it to a Python exception). */
SWIGINTERN std::string sentencepiece_SentencePieceNormalizer__Normalize(sentencepiece::SentencePieceNormalizer *self,absl::string_view text){
std::string result;
const auto _status = self->Normalize(text, &result);
if (!_status.ok()) throw _status;
return result;
}
/* SWIG-generated helper: normalizes `text` and also returns the
   normalized-to-original offset mapping (result.second). */
SWIGINTERN std::pair< std::string,std::vector< size_t > > sentencepiece_SentencePieceNormalizer__NormalizeWithOffsets(sentencepiece::SentencePieceNormalizer *self,absl::string_view text){
std::pair<std::string, std::vector<size_t>> result;
const auto _status = self->Normalize(text, &result.first, &result.second);
if (!_status.ok()) throw _status;
return result;
}
/* SWIG-generated helper: sets a boolean NormalizerSpec field ("1"/"0");
   any error from SetProtoField is deliberately ignored. */
SWIGINTERN void sentencepiece_SentencePieceNormalizer__SetProtoField(sentencepiece::SentencePieceNormalizer *self,absl::string_view name,bool value){
sentencepiece::SentencePieceTrainer::SetProtoField(
name,
value ? "1" : "0",
self->mutable_normalizer_spec()).IgnoreError();
}
#ifdef __cplusplus
extern "C" {
#endif
@ -8846,6 +8868,419 @@ SWIGINTERN PyObject *SentencePieceTrainer_swigregister(PyObject *SWIGUNUSEDPARM(
return SWIG_Py_Void();
}
/* Python wrapper for `new sentencepiece::SentencePieceNormalizer()`.
   Takes no arguments; a util::Status thrown during construction is
   converted to a Python exception via SWIG_exception. */
SWIGINTERN PyObject *_wrap_new_SentencePieceNormalizer(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *result = 0 ;
if (!SWIG_Python_UnpackTuple(args, "new_SentencePieceNormalizer", 0, 0, 0)) SWIG_fail;
{
try {
result = (sentencepiece::SentencePieceNormalizer *)new sentencepiece::SentencePieceNormalizer();
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
/* SWIG_POINTER_NEW: the returned Python object owns the C++ instance. */
resultobj = SWIG_NewPointerObj(SWIG_as_voidptr(result), SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, SWIG_POINTER_NEW | 0 );
return resultobj;
fail:
return NULL;
}
/* Python wrapper for the destructor.  SWIG_POINTER_DISOWN transfers
   ownership from the Python proxy before `delete`. */
SWIGINTERN PyObject *_wrap_delete_SentencePieceNormalizer(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[1] ;
if (!args) SWIG_fail;
swig_obj[0] = args;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, SWIG_POINTER_DISOWN | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "delete_SentencePieceNormalizer" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
try {
delete arg1;
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
resultobj = SWIG_Py_Void();
return resultobj;
fail:
return NULL;
}
/* The three wrappers below share one pattern: unpack (self, str) args,
   convert the Python string via PyInputString to an absl::string_view,
   call the Load* method, raise a Python exception if the returned
   util::Status is not ok, and otherwise return True. */

/* SentencePieceNormalizer.LoadFromSerializedProto(serialized) */
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_LoadFromSerializedProto(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
absl::string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[2] ;
sentencepiece::util::Status result;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer_LoadFromSerializedProto", 2, 2, swig_obj)) SWIG_fail;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_LoadFromSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
const PyInputString ustring(swig_obj[1]);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
/* resultobj temporarily records the input's Python string type. */
resultobj = ustring.input_type();
arg2 = ustring.str();
}
{
try {
result = (arg1)->LoadFromSerializedProto(arg2);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
if (!(&result)->ok()) {
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
}
resultobj = SWIG_From_bool((&result)->ok());
}
return resultobj;
fail:
return NULL;
}
/* SentencePieceNormalizer.LoadFromRuleTSV(filename) */
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_LoadFromRuleTSV(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
absl::string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[2] ;
sentencepiece::util::Status result;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer_LoadFromRuleTSV", 2, 2, swig_obj)) SWIG_fail;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_LoadFromRuleTSV" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
const PyInputString ustring(swig_obj[1]);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
resultobj = ustring.input_type();
arg2 = ustring.str();
}
{
try {
result = (arg1)->LoadFromRuleTSV(arg2);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
if (!(&result)->ok()) {
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
}
resultobj = SWIG_From_bool((&result)->ok());
}
return resultobj;
fail:
return NULL;
}
/* SentencePieceNormalizer.LoadFromRuleName(name) */
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_LoadFromRuleName(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
absl::string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[2] ;
sentencepiece::util::Status result;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer_LoadFromRuleName", 2, 2, swig_obj)) SWIG_fail;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_LoadFromRuleName" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
const PyInputString ustring(swig_obj[1]);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
resultobj = ustring.input_type();
arg2 = ustring.str();
}
{
try {
result = (arg1)->LoadFromRuleName(arg2);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
if (!(&result)->ok()) {
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
}
resultobj = SWIG_From_bool((&result)->ok());
}
return resultobj;
fail:
return NULL;
}
/* SentencePieceNormalizer.serialized_model_proto() -> bytes/str.
   Calls the const member and converts the returned std::string back to a
   Python string of the recorded input type. */
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_serialized_model_proto(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[1] ;
std::string result;
if (!args) SWIG_fail;
swig_obj[0] = args;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_serialized_model_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer const *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
try {
result = ((sentencepiece::SentencePieceNormalizer const *)arg1)->serialized_model_proto();
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
PyObject *input_type = resultobj;
resultobj = MakePyOutputString(result, input_type);
}
return resultobj;
fail:
return NULL;
}
/* SentencePieceNormalizer.LoadFromFile(path) -> True.
   Delegates to the %extend helper; raises on a non-ok util::Status. */
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_LoadFromFile(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
absl::string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[2] ;
sentencepiece::util::Status result;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer_LoadFromFile", 2, 2, swig_obj)) SWIG_fail;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_LoadFromFile" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
const PyInputString ustring(swig_obj[1]);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
resultobj = ustring.input_type();
arg2 = ustring.str();
}
{
try {
result = sentencepiece_SentencePieceNormalizer_LoadFromFile(arg1,SWIG_STD_MOVE(arg2));
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
if (!(&result)->ok()) {
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
}
resultobj = SWIG_From_bool((&result)->ok());
}
return resultobj;
fail:
return NULL;
}
/* SentencePieceNormalizer._Normalize(text) -> normalized string.
   The output string is converted to the same Python string type as the
   input (bytes in, bytes out). */
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer__Normalize(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
absl::string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[2] ;
std::string result;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer__Normalize", 2, 2, swig_obj)) SWIG_fail;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer__Normalize" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
const PyInputString ustring(swig_obj[1]);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
resultobj = ustring.input_type();
arg2 = ustring.str();
}
{
try {
result = sentencepiece_SentencePieceNormalizer__Normalize(arg1,SWIG_STD_MOVE(arg2));
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
PyObject *input_type = resultobj;
resultobj = MakePyOutputString(result, input_type);
}
return resultobj;
fail:
return NULL;
}
/* SentencePieceNormalizer._NormalizeWithOffsets(text)
   -> (normalized string, list of int offsets). */
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer__NormalizeWithOffsets(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
absl::string_view arg2 ;
void *argp1 = 0 ;
int res1 = 0 ;
PyObject *swig_obj[2] ;
std::pair< std::string,std::vector< size_t > > result;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer__NormalizeWithOffsets", 2, 2, swig_obj)) SWIG_fail;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer__NormalizeWithOffsets" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
const PyInputString ustring(swig_obj[1]);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
resultobj = ustring.input_type();
arg2 = ustring.str();
}
{
try {
result = sentencepiece_SentencePieceNormalizer__NormalizeWithOffsets(arg1,SWIG_STD_MOVE(arg2));
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
{
PyObject *input_type = resultobj;
/* PyList_SET_ITEM steals the reference created by PyInt_FromLong. */
PyObject *obj = PyList_New((&result)->second.size());
for (size_t i = 0; i < (&result)->second.size(); ++i) {
PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>((&result)->second[i])));
}
resultobj = PyTuple_Pack(2, MakePyOutputString((&result)->first, input_type), obj);
}
return resultobj;
fail:
return NULL;
}
/* SentencePieceNormalizer._SetProtoField(name, value) -> None.
   Sets a boolean field of the underlying NormalizerSpec. */
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer__SetProtoField(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
absl::string_view arg2 ;
bool arg3 ;
void *argp1 = 0 ;
int res1 = 0 ;
bool val3 ;
int ecode3 = 0 ;
PyObject *swig_obj[3] ;
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer__SetProtoField", 3, 3, swig_obj)) SWIG_fail;
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
if (!SWIG_IsOK(res1)) {
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer__SetProtoField" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
}
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
{
const PyInputString ustring(swig_obj[1]);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
resultobj = ustring.input_type();
arg2 = ustring.str();
}
ecode3 = SWIG_AsVal_bool(swig_obj[2], &val3);
if (!SWIG_IsOK(ecode3)) {
SWIG_exception_fail(SWIG_ArgError(ecode3), "in method '" "SentencePieceNormalizer__SetProtoField" "', argument " "3"" of type '" "bool""'");
}
arg3 = static_cast< bool >(val3);
{
try {
sentencepiece_SentencePieceNormalizer__SetProtoField(arg1,SWIG_STD_MOVE(arg2),arg3);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
resultobj = SWIG_Py_Void();
return resultobj;
fail:
return NULL;
}
/* Registers the SWIG type client data for the Python proxy class. */
SWIGINTERN PyObject *SentencePieceNormalizer_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
PyObject *obj;
if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL;
SWIG_TypeNewClientData(SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, SWIG_NewClientData(obj));
return SWIG_Py_Void();
}
/* Initializes the shadow instance for the proxy class constructor. */
SWIGINTERN PyObject *SentencePieceNormalizer_swiginit(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
return SWIG_Python_InitShadowInstance(args);
}
static PyMethodDef SwigMethods[] = {
{ "new_ImmutableSentencePieceText_ImmutableSentencePiece", _wrap_new_ImmutableSentencePieceText_ImmutableSentencePiece, METH_NOARGS, NULL},
{ "delete_ImmutableSentencePieceText_ImmutableSentencePiece", _wrap_delete_ImmutableSentencePieceText_ImmutableSentencePiece, METH_O, NULL},
@ -8937,6 +9372,18 @@ static PyMethodDef SwigMethods[] = {
{ "SentencePieceTrainer__TrainFromMap3", _wrap_SentencePieceTrainer__TrainFromMap3, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromMap4", _wrap_SentencePieceTrainer__TrainFromMap4, METH_VARARGS, NULL},
{ "SentencePieceTrainer_swigregister", SentencePieceTrainer_swigregister, METH_O, NULL},
{ "new_SentencePieceNormalizer", _wrap_new_SentencePieceNormalizer, METH_NOARGS, NULL},
{ "delete_SentencePieceNormalizer", _wrap_delete_SentencePieceNormalizer, METH_O, NULL},
{ "SentencePieceNormalizer_LoadFromSerializedProto", _wrap_SentencePieceNormalizer_LoadFromSerializedProto, METH_VARARGS, NULL},
{ "SentencePieceNormalizer_LoadFromRuleTSV", _wrap_SentencePieceNormalizer_LoadFromRuleTSV, METH_VARARGS, NULL},
{ "SentencePieceNormalizer_LoadFromRuleName", _wrap_SentencePieceNormalizer_LoadFromRuleName, METH_VARARGS, NULL},
{ "SentencePieceNormalizer_serialized_model_proto", _wrap_SentencePieceNormalizer_serialized_model_proto, METH_O, NULL},
{ "SentencePieceNormalizer_LoadFromFile", _wrap_SentencePieceNormalizer_LoadFromFile, METH_VARARGS, NULL},
{ "SentencePieceNormalizer__Normalize", _wrap_SentencePieceNormalizer__Normalize, METH_VARARGS, NULL},
{ "SentencePieceNormalizer__NormalizeWithOffsets", _wrap_SentencePieceNormalizer__NormalizeWithOffsets, METH_VARARGS, NULL},
{ "SentencePieceNormalizer__SetProtoField", _wrap_SentencePieceNormalizer__SetProtoField, METH_VARARGS, NULL},
{ "SentencePieceNormalizer_swigregister", SentencePieceNormalizer_swigregister, METH_O, NULL},
{ "SentencePieceNormalizer_swiginit", SentencePieceNormalizer_swiginit, METH_VARARGS, NULL},
{ NULL, NULL, 0, NULL }
};
@ -8949,6 +9396,7 @@ static swig_type_info _swigt__p_sentencepiece__ImmutableNBestSentencePieceText =
static swig_type_info _swigt__p_sentencepiece__ImmutableSentencePieceText = {"_p_sentencepiece__ImmutableSentencePieceText", "sentencepiece::ImmutableSentencePieceText *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece = {"_p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece", "sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_sentencepiece__SentenceIterator = {"_p_sentencepiece__SentenceIterator", "sentencepiece::SentenceIterator *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_sentencepiece__SentencePieceNormalizer = {"_p_sentencepiece__SentencePieceNormalizer", "sentencepiece::SentencePieceNormalizer *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_sentencepiece__SentencePieceProcessor = {"_p_sentencepiece__SentencePieceProcessor", "sentencepiece::SentencePieceProcessor *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_sentencepiece__SentencePieceTrainer = {"_p_sentencepiece__SentencePieceTrainer", "sentencepiece::SentencePieceTrainer *", 0, 0, (void*)0, 0};
static swig_type_info _swigt__p_std__string = {"_p_std__string", "sentencepiece::util::bytes *|std::string *", 0, 0, (void*)0, 0};
@ -8965,6 +9413,7 @@ static swig_type_info *swig_type_initial[] = {
&_swigt__p_sentencepiece__ImmutableSentencePieceText,
&_swigt__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece,
&_swigt__p_sentencepiece__SentenceIterator,
&_swigt__p_sentencepiece__SentencePieceNormalizer,
&_swigt__p_sentencepiece__SentencePieceProcessor,
&_swigt__p_sentencepiece__SentencePieceTrainer,
&_swigt__p_std__string,
@ -8981,6 +9430,7 @@ static swig_cast_info _swigc__p_sentencepiece__ImmutableNBestSentencePieceText[]
static swig_cast_info _swigc__p_sentencepiece__ImmutableSentencePieceText[] = { {&_swigt__p_sentencepiece__ImmutableSentencePieceText, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece[] = { {&_swigt__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_sentencepiece__SentenceIterator[] = { {&_swigt__p_sentencepiece__SentenceIterator, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_sentencepiece__SentencePieceNormalizer[] = { {&_swigt__p_sentencepiece__SentencePieceNormalizer, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_sentencepiece__SentencePieceProcessor[] = { {&_swigt__p_sentencepiece__SentencePieceProcessor, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_sentencepiece__SentencePieceTrainer[] = { {&_swigt__p_sentencepiece__SentencePieceTrainer, 0, 0, 0},{0, 0, 0, 0}};
static swig_cast_info _swigc__p_std__string[] = { {&_swigt__p_std__string, 0, 0, 0},{0, 0, 0, 0}};
@ -8997,6 +9447,7 @@ static swig_cast_info *swig_cast_initial[] = {
_swigc__p_sentencepiece__ImmutableSentencePieceText,
_swigc__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece,
_swigc__p_sentencepiece__SentenceIterator,
_swigc__p_sentencepiece__SentencePieceNormalizer,
_swigc__p_sentencepiece__SentencePieceProcessor,
_swigc__p_sentencepiece__SentencePieceTrainer,
_swigc__p_std__string,

View File

@ -790,6 +790,64 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual('▁平成', x[1][0])
self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0, 3], x[1][1])
def test_normalizer(self):
# Exercises SentencePieceNormalizer loaded from a model file: single string
# vs. list input, snake_case vs. CamelCase method aliases, offset output,
# and the three whitespace-handling flags.
# NOTE(review): several literals below ('KADOKAWAABC', 'ABC', '', '平成')
# appear mojibake'd — the original full-width/era-name characters were
# stripped by the diff rendering.  Verify against the repository before
# relying on this transcription; they are kept byte-identical here.
sp = spm.SentencePieceNormalizer(
model_file=os.path.join('test', 'test_model.model')
)
self.assertEqual('KADOKAWAABC', sp.normalize('ABC'))
self.assertEqual('KADOKAWAABC', sp.Normalize('ABC'))
# with_offsets=True returns (normalized_text, offsets).
x = sp.Normalize('ABC', with_offsets=True)
self.assertEqual('KADOKAWAABC', x[0])
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1])
# List input returns a list of results, in order.
self.assertEqual(
['KADOKAWAABC', '平成'], sp.normalize(['ABC', ''])
)
self.assertEqual(
['KADOKAWAABC', '平成'], sp.Normalize(['ABC', ''])
)
x = sp.Normalize(['ABC', ''], with_offsets=True)
self.assertEqual(len(x), 2)
self.assertEqual('KADOKAWAABC', x[0][0])
self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1])
self.assertEqual('平成', x[1][0])
self.assertEqual([0, 0, 0, 0, 0, 0, 3], x[1][1])
# Flag combinations override the spec loaded from the model file.
sp = spm.SentencePieceNormalizer(
model_file=os.path.join('test', 'test_model.model'),
add_dummy_prefix=True,
escape_whitespaces=True,
remove_extra_whitespaces=False,
)
self.assertEqual('▁hello▁▁world', sp.normalize('hello world'))
sp = spm.SentencePieceNormalizer(
model_file=os.path.join('test', 'test_model.model'),
add_dummy_prefix=True,
escape_whitespaces=True,
remove_extra_whitespaces=True,
)
self.assertEqual('▁hello▁world', sp.normalize(' hello world '))
sp = spm.SentencePieceNormalizer(
model_file=os.path.join('test', 'test_model.model'),
add_dummy_prefix=False,
escape_whitespaces=False,
remove_extra_whitespaces=True,
)
self.assertEqual('hello world', sp.normalize(' hello world '))
def test_normalizer_rule(self):
# Exercises rule_name-based initialization (no model file).
# NOTE(review): the Normalize('') argument literals look stripped by the
# diff rendering (likely full-width input characters); kept byte-identical.
sp = spm.SentencePieceNormalizer(rule_name='identity')
self.assertEqual('', sp.Normalize(''))
sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
self.assertEqual('abc', sp.Normalize(''))
def suite():
suite = unittest.TestSuite()

View File

@ -295,4 +295,75 @@ SentencePieceTrainer::GetPretokenizerForTraining() {
return g_pretokenizer;
}
// Default-constructed normalizer holds no model; call one of the Load*
// methods before Normalize().  `= default` (modernize-use-equals-default)
// instead of an empty body.
SentencePieceNormalizer::SentencePieceNormalizer() = default;

// Out-of-line so the std::unique_ptr members can destroy types that are
// only forward-declared in the header (ModelProto, normalizer::Normalizer).
SentencePieceNormalizer::~SentencePieceNormalizer() = default;
// Takes ownership of `model_proto` and compiles its normalizer_spec into a
// normalizer::Normalizer.  Returns the normalizer's status, i.e. non-OK if
// the spec could not be compiled.
util::Status SentencePieceNormalizer::Load(
std::unique_ptr<ModelProto> model_proto) {
model_proto_ = std::move(model_proto);
normalizer_ =
std::make_unique<normalizer::Normalizer>(model_proto_->normalizer_spec());
// Defensive: make_unique should never yield null; bail out if it did.
CHECK_OR_RETURN(normalizer_);
return normalizer_->status();
}
// Loads a serialized model from `filename` and delegates to the
// unique_ptr overload.  Returns a non-OK status if the file cannot be
// read or parsed.
util::Status SentencePieceNormalizer::Load(absl::string_view filename) {
  auto proto = std::make_unique<ModelProto>();
  RETURN_IF_ERROR(io::LoadModelProto(filename, proto.get()));
  return Load(std::move(proto));
}
// Deserializes a ModelProto from wire-format bytes in `serialized` and
// loads it.  Returns a non-OK status on parse failure.
util::Status SentencePieceNormalizer::LoadFromSerializedProto(
    absl::string_view serialized) {
  auto proto = std::make_unique<ModelProto>();
  CHECK_OR_RETURN(proto->ParseFromArray(serialized.data(), serialized.size()));
  return Load(std::move(proto));
}
// Builds a fresh model whose normalizer spec points at the TSV rule file
// `filename`; PopulateNormalizerSpec compiles the rules into the spec
// before the model is loaded.
util::Status SentencePieceNormalizer::LoadFromRuleTSV(
    absl::string_view filename) {
  auto proto = std::make_unique<ModelProto>();
  NormalizerSpec *norm_spec = proto->mutable_normalizer_spec();
  norm_spec->set_normalization_rule_tsv(std::string(filename));
  RETURN_IF_ERROR(SentencePieceTrainer::PopulateNormalizerSpec(norm_spec));
  return Load(std::move(proto));
}
// Builds a fresh model from a built-in rule `name` (e.g. "nfkc_cf",
// "identity"); PopulateNormalizerSpec resolves the name and fills in the
// compiled rules.  Returns a non-OK status for unknown names.
util::Status SentencePieceNormalizer::LoadFromRuleName(absl::string_view name) {
  auto proto = std::make_unique<ModelProto>();
  NormalizerSpec *norm_spec = proto->mutable_normalizer_spec();
  norm_spec->set_name(std::string(name));
  RETURN_IF_ERROR(SentencePieceTrainer::PopulateNormalizerSpec(norm_spec));
  return Load(std::move(proto));
}
// Normalizes `input` into `*normalized`.  Requires that a model has been
// loaded; returns a non-OK status otherwise.
util::Status SentencePieceNormalizer::Normalize(absl::string_view input,
                                                std::string *normalized) const {
  CHECK_OR_RETURN(normalizer_);
  // The underlying normalizer always produces offsets; this overload
  // simply discards them.
  std::vector<size_t> discarded_offsets;
  return normalizer_->Normalize(input, normalized, &discarded_offsets);
}
// Normalizes `input` into `*normalized` and writes, for each byte of the
// normalized text, the originating byte offset in `input` into
// `*norm_to_orig`.  Requires a loaded model.
util::Status SentencePieceNormalizer::Normalize(
absl::string_view input, std::string *normalized,
std::vector<size_t> *norm_to_orig) const {
CHECK_OR_RETURN(normalizer_);
return normalizer_->Normalize(input, normalized, norm_to_orig);
}
// Convenience overload: returns the normalized text by value.  Since there
// is no status to report, failures are ignored and an empty (or partial)
// string is returned.
std::string SentencePieceNormalizer::Normalize(absl::string_view input) const {
  std::string result;
  Normalize(input, &result).IgnoreError();
  return result;
}
// Exposes the loaded model's normalizer spec for in-place tweaking
// (e.g. toggling add_dummy_prefix); nullptr when no model is loaded.
NormalizerSpec *SentencePieceNormalizer::mutable_normalizer_spec() const {
  if (!model_proto_) return nullptr;
  return model_proto_->mutable_normalizer_spec();
}
// Returns the loaded model serialized to wire format, or an empty string
// when no model is loaded.
std::string SentencePieceNormalizer::serialized_model_proto() const {
  if (!model_proto_) return "";
  return model_proto_->SerializeAsString();
}
} // namespace sentencepiece

View File

@ -17,6 +17,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "sentencepiece_processor.h"
@ -24,11 +25,16 @@ namespace sentencepiece {
class TrainerSpec;
class NormalizerSpec;
class ModelProto;
namespace pretokenizer {
class PretokenizerForTrainingInterface;
} // namespace pretokenizer
namespace normalizer {
class Normalizer;
} // namespace normalizer
// Iterator over the training sentences.
// Training sentences are loaded sequentially as follows:
//
@ -158,6 +164,39 @@ class SentencePieceTrainer {
~SentencePieceTrainer() {}
};
// Standalone text normalizer that applies a sentencepiece model's
// normalization rules without tokenizing.  Load exactly one source
// (model proto, model file, rule TSV, or built-in rule name) before
// calling Normalize().
class SentencePieceNormalizer {
public:
SentencePieceNormalizer();
virtual ~SentencePieceNormalizer();
// Takes ownership of `model_proto` and compiles its normalizer spec.
virtual util::Status Load(std::unique_ptr<ModelProto> model_proto);
// Loads a serialized model from a file on disk.
virtual util::Status Load(absl::string_view filename);
// Loads a model from wire-format bytes.
virtual util::Status LoadFromSerializedProto(absl::string_view serialized);
// Builds the normalizer from a normalization-rule TSV file.
virtual util::Status LoadFromRuleTSV(absl::string_view filename);
// Builds the normalizer from a built-in rule name (e.g. "nfkc_cf").
virtual util::Status LoadFromRuleName(absl::string_view name);
// Normalizes `input` into `*normalized`.
virtual util::Status Normalize(absl::string_view input,
std::string *normalized) const;
// As above, also emitting normalized-byte -> input-byte offsets.
virtual util::Status Normalize(absl::string_view input,
std::string *normalized,
std::vector<size_t> *norm_to_orig) const;
// Convenience overload; errors are ignored and "" returned.
virtual std::string Normalize(absl::string_view input) const;
// Mutable access to the loaded spec; nullptr if nothing is loaded.
virtual NormalizerSpec *mutable_normalizer_spec() const;
// Serialized loaded model, or "" if nothing is loaded.
virtual std::string serialized_model_proto() const;
private:
std::unique_ptr<normalizer::Normalizer> normalizer_;
std::unique_ptr<ModelProto> model_proto_;
};
} // namespace sentencepiece
#endif // SENTENCEPIECE_TRAINER_H_

View File

@ -364,5 +364,94 @@ TEST(SentencePieceTrainerTest, PopulateModelTypeFromStringTest) {
SentencePieceTrainer::PopulateModelTypeFromString("", &spec).ok());
}
// End-to-end check that SentencePieceNormalizer matches
// SentencePieceProcessor's normalization on a freshly trained model, and
// that the TSV / rule-name loading paths work.
// NOTE(review): indentation appears to have been stripped in this view.
TEST(SentencePieceTrainerTest, NormalizationTest) {
// Train a small model in the test tmpdir to normalize against.
const auto model_prefix =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
const auto model_file = absl::StrCat(model_prefix, ".model");
TrainerSpec trainer_spec;
trainer_spec.add_input(
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData));
trainer_spec.set_model_prefix(model_prefix);
trainer_spec.set_vocab_size(1000);
ASSERT_TRUE(SentencePieceTrainer::Train(trainer_spec).ok());
// Baseline: the full processor's normalization output and offsets.
{
SentencePieceProcessor sp;
EXPECT_OK(sp.Load(model_file));
EXPECT_EQ(sp.Normalize(" ABC "), "▁KADOKAWA▁ABC");
std::string normalized;
std::vector<size_t> offsets;
EXPECT_OK(sp.Normalize(" ABC ", &normalized, &offsets));
EXPECT_EQ(normalized, "▁KADOKAWA▁ABC");
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21,
24, 24, 24, 27, 28, 29, 30}));
}
// The standalone normalizer must produce identical text and offsets.
{
SentencePieceNormalizer sp;
EXPECT_OK(sp.Load(model_file));
EXPECT_EQ(sp.Normalize(" ABC "), "▁KADOKAWA▁ABC");
std::string normalized;
std::vector<size_t> offsets;
EXPECT_OK(sp.Normalize(" ABC ", &normalized, &offsets));
EXPECT_EQ(normalized, "▁KADOKAWA▁ABC");
EXPECT_EQ(offsets, std::vector<size_t>({0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21,
24, 24, 24, 27, 28, 29, 30}));
}
// Helper: disable prefix/escape/extra-whitespace handling via the
// mutable spec so only the character-level rules remain.
auto set_normalization_only = [](SentencePieceNormalizer *normalizer) {
SentencePieceTrainer::SetProtoField("add_dummy_prefix", "false",
normalizer->mutable_normalizer_spec());
SentencePieceTrainer::SetProtoField("escape_whitespaces", "false",
normalizer->mutable_normalizer_spec());
SentencePieceTrainer::SetProtoField("remove_extra_whitespaces", "false",
normalizer->mutable_normalizer_spec());
};
// With the extras disabled, whitespace passes through untouched.
{
SentencePieceNormalizer sp;
EXPECT_OK(sp.Load(model_file));
set_normalization_only(&sp);
EXPECT_EQ(sp.Normalize(" ABC "), "KADOKAWA ABC ");
}
// Loading rules directly from a TSV file.
{
SentencePieceNormalizer sp;
EXPECT_OK(sp.LoadFromRuleTSV(
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), "nfkc_cf.tsv")));
set_normalization_only(&sp);
EXPECT_EQ(sp.Normalize("ABCD"), "abcd");
}
// Nonexistent TSV path must fail to load.
{
SentencePieceNormalizer sp;
EXPECT_FALSE(sp.LoadFromRuleTSV("__unknown__").ok());
}
// Built-in rule names: nfkc_cf case-folds…
{
SentencePieceNormalizer sp;
EXPECT_OK(sp.LoadFromRuleName("nfkc_cf"));
set_normalization_only(&sp);
EXPECT_EQ(sp.Normalize("ABCD"), "abcd");
}
// …and identity passes text through unchanged.
{
SentencePieceNormalizer sp;
EXPECT_OK(sp.LoadFromRuleName("identity"));
set_normalization_only(&sp);
EXPECT_EQ(sp.Normalize(""), "");
}
// Unknown rule names must be rejected.
{
SentencePieceNormalizer sp;
EXPECT_FALSE(sp.LoadFromRuleName("__unknown__").ok());
}
}
} // namespace
} // namespace sentencepiece