mirror of
https://github.com/google/sentencepiece.git
synced 2024-12-29 11:11:58 +03:00
add more advanced SentencePieceNormalizer class
This commit is contained in:
parent
f5c736302c
commit
ed76ecc478
@ -1004,6 +1004,98 @@ class SentencePieceTrainer(object):
|
||||
|
||||
# Register SentencePieceTrainer in _sentencepiece:
|
||||
_sentencepiece.SentencePieceTrainer_swigregister(SentencePieceTrainer)
|
||||
class SentencePieceNormalizer(object):
|
||||
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
|
||||
__repr__ = _swig_repr
|
||||
|
||||
def __init__(self):
|
||||
_sentencepiece.SentencePieceNormalizer_swiginit(self, _sentencepiece.new_SentencePieceNormalizer())
|
||||
__swig_destroy__ = _sentencepiece.delete_SentencePieceNormalizer
|
||||
|
||||
def LoadFromSerializedProto(self, serialized):
|
||||
return _sentencepiece.SentencePieceNormalizer_LoadFromSerializedProto(self, serialized)
|
||||
|
||||
def LoadFromRuleTSV(self, filename):
|
||||
return _sentencepiece.SentencePieceNormalizer_LoadFromRuleTSV(self, filename)
|
||||
|
||||
def LoadFromRuleName(self, name):
|
||||
return _sentencepiece.SentencePieceNormalizer_LoadFromRuleName(self, name)
|
||||
|
||||
def serialized_model_proto(self):
|
||||
return _sentencepiece.SentencePieceNormalizer_serialized_model_proto(self)
|
||||
|
||||
def LoadFromFile(self, arg):
|
||||
return _sentencepiece.SentencePieceNormalizer_LoadFromFile(self, arg)
|
||||
|
||||
def _Normalize(self, text):
|
||||
return _sentencepiece.SentencePieceNormalizer__Normalize(self, text)
|
||||
|
||||
def _NormalizeWithOffsets(self, text):
|
||||
return _sentencepiece.SentencePieceNormalizer__NormalizeWithOffsets(self, text)
|
||||
|
||||
def _SetProtoField(self, name, value):
|
||||
return _sentencepiece.SentencePieceNormalizer__SetProtoField(self, name, value)
|
||||
|
||||
def Init(self,
|
||||
model_file=None,
|
||||
model_proto=None,
|
||||
rule_tsv=None,
|
||||
rule_name=None,
|
||||
add_dummy_prefix=False,
|
||||
escape_whitespaces=False,
|
||||
remove_extra_whitespaces=False):
|
||||
"""Initialzie sentencePieceNormalizer.
|
||||
|
||||
Args:
|
||||
model_file: The sentencepiece model file path.
|
||||
model_proto: The sentencepiece model serialized proto.
|
||||
rule_tsv: The normalization rule file in TSV format.
|
||||
rule_name: Pre-defined normalization name.
|
||||
add_dummy_prefix: add dummy prefix.
|
||||
escape_whitespaces: escape whitespaces.
|
||||
remove_extra_whitespaces: remove extra whitespaces.
|
||||
"""
|
||||
|
||||
_sentencepiece_normalizer_init_native(self)
|
||||
|
||||
if model_file:
|
||||
status = self.LoadFromFile(model_file)
|
||||
elif model_proto:
|
||||
status = self.LoadFromSerializedProto(model_proto)
|
||||
elif rule_tsv:
|
||||
status = self.LoadFromRuleTSV(rule_tsv)
|
||||
elif rule_name:
|
||||
status = self.LoadFromRuleName(rule_name)
|
||||
else:
|
||||
raise RuntimeError('no model is specified')
|
||||
|
||||
if status:
|
||||
self._SetProtoField('add_dummy_prefix', add_dummy_prefix)
|
||||
self._SetProtoField('escape_whitespaces', escape_whitespaces)
|
||||
self._SetProtoField('remove_extra_whitespaces', remove_extra_whitespaces)
|
||||
|
||||
def Normalize(self, input, with_offsets=None):
|
||||
def _normalize(text):
|
||||
if with_offsets:
|
||||
return self._NormalizeWithOffsets(text)
|
||||
return self._Normalize(text)
|
||||
|
||||
if type(input) is list:
|
||||
return [_normalize(x) for x in input]
|
||||
return _normalize(input)
|
||||
|
||||
|
||||
def __getstate__(self):
|
||||
return self.serialized_model_proto()
|
||||
|
||||
|
||||
def __setstate__(self, serialized_model_proto):
|
||||
self.__init__()
|
||||
self.LoadFromSerializedProto(serialized_model_proto)
|
||||
|
||||
|
||||
# Register SentencePieceNormalizer in _sentencepiece:
|
||||
_sentencepiece.SentencePieceNormalizer_swigregister(SentencePieceNormalizer)
|
||||
|
||||
|
||||
import re
|
||||
@ -1045,7 +1137,9 @@ def _batchnize(classname, name):
|
||||
|
||||
|
||||
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
|
||||
_sentencepiece_normalizer_init_native = SentencePieceNormalizer.__init__
|
||||
setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
|
||||
setattr(SentencePieceNormalizer, '__init__', SentencePieceNormalizer.Init)
|
||||
|
||||
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
|
||||
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
|
||||
@ -1058,6 +1152,7 @@ for m in [
|
||||
|
||||
_add_snake_case(SentencePieceProcessor)
|
||||
_add_snake_case(SentencePieceTrainer)
|
||||
_add_snake_case(SentencePieceNormalizer)
|
||||
set_random_generator_seed = SetRandomGeneratorSeed
|
||||
set_min_log_level = SetMinLogLevel
|
||||
|
||||
|
@ -368,6 +368,10 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
|
||||
%ignore sentencepiece::SentencePieceTrainer::SetPretokenizerForTraining;
|
||||
%ignore sentencepiece::SentencePieceTrainer::GetPretokenizerForTraining;
|
||||
|
||||
%ignore sentencepiece::SentencePieceNormalizer::Load;
|
||||
%ignore sentencepiece::SentencePieceNormalizer::Normalize;
|
||||
%ignore sentencepiece::SentencePieceNormalizer::mutable_normalizer_spec;
|
||||
|
||||
%ignore sentencepiece::io::LoadModelProto;
|
||||
%ignore sentencepiece::io::SaveModelProto;
|
||||
|
||||
@ -1293,6 +1297,92 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
|
||||
}
|
||||
}
|
||||
|
||||
%extend sentencepiece::SentencePieceNormalizer {
|
||||
sentencepiece::util::Status LoadFromFile(absl::string_view arg) {
|
||||
return $self->Load(arg);
|
||||
}
|
||||
|
||||
std::string _Normalize(absl::string_view text) {
|
||||
std::string result;
|
||||
const auto _status = $self->Normalize(text, &result);
|
||||
if (!_status.ok()) throw _status;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::pair<std::string, std::vector<size_t>> _NormalizeWithOffsets(absl::string_view text) {
|
||||
std::pair<std::string, std::vector<size_t>> result;
|
||||
const auto _status = $self->Normalize(text, &result.first, &result.second);
|
||||
if (!_status.ok()) throw _status;
|
||||
return result;
|
||||
}
|
||||
|
||||
void _SetProtoField(absl::string_view name, bool value) {
|
||||
sentencepiece::SentencePieceTrainer::SetProtoField(
|
||||
name,
|
||||
value ? "1" : "0",
|
||||
$self->mutable_normalizer_spec()).IgnoreError();
|
||||
}
|
||||
|
||||
%pythoncode %{
|
||||
def Init(self,
|
||||
model_file=None,
|
||||
model_proto=None,
|
||||
rule_tsv=None,
|
||||
rule_name=None,
|
||||
add_dummy_prefix=False,
|
||||
escape_whitespaces=False,
|
||||
remove_extra_whitespaces=False):
|
||||
"""Initialzie sentencePieceNormalizer.
|
||||
|
||||
Args:
|
||||
model_file: The sentencepiece model file path.
|
||||
model_proto: The sentencepiece model serialized proto.
|
||||
rule_tsv: The normalization rule file in TSV format.
|
||||
rule_name: Pre-defined normalization name.
|
||||
add_dummy_prefix: add dummy prefix.
|
||||
escape_whitespaces: escape whitespaces.
|
||||
remove_extra_whitespaces: remove extra whitespaces.
|
||||
"""
|
||||
|
||||
_sentencepiece_normalizer_init_native(self)
|
||||
|
||||
if model_file:
|
||||
status = self.LoadFromFile(model_file)
|
||||
elif model_proto:
|
||||
status = self.LoadFromSerializedProto(model_proto)
|
||||
elif rule_tsv:
|
||||
status = self.LoadFromRuleTSV(rule_tsv)
|
||||
elif rule_name:
|
||||
status = self.LoadFromRuleName(rule_name)
|
||||
else:
|
||||
raise RuntimeError('no model is specified')
|
||||
|
||||
if status:
|
||||
self._SetProtoField('add_dummy_prefix', add_dummy_prefix)
|
||||
self._SetProtoField('escape_whitespaces', escape_whitespaces)
|
||||
self._SetProtoField('remove_extra_whitespaces', remove_extra_whitespaces)
|
||||
|
||||
def Normalize(self, input, with_offsets=None):
|
||||
def _normalize(text):
|
||||
if with_offsets:
|
||||
return self._NormalizeWithOffsets(text)
|
||||
return self._Normalize(text)
|
||||
|
||||
if type(input) is list:
|
||||
return [_normalize(x) for x in input]
|
||||
return _normalize(input)
|
||||
|
||||
|
||||
def __getstate__(self):
|
||||
return self.serialized_model_proto()
|
||||
|
||||
|
||||
def __setstate__(self, serialized_model_proto):
|
||||
self.__init__()
|
||||
self.LoadFromSerializedProto(serialized_model_proto)
|
||||
%}
|
||||
}
|
||||
|
||||
%extend sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece {
|
||||
%rename(_piece) piece;
|
||||
%rename(_id) id;
|
||||
@ -1790,7 +1880,9 @@ def _batchnize(classname, name):
|
||||
|
||||
|
||||
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
|
||||
_sentencepiece_normalizer_init_native = SentencePieceNormalizer.__init__
|
||||
setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
|
||||
setattr(SentencePieceNormalizer, '__init__', SentencePieceNormalizer.Init)
|
||||
|
||||
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
|
||||
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
|
||||
@ -1803,6 +1895,7 @@ for m in [
|
||||
|
||||
_add_snake_case(SentencePieceProcessor)
|
||||
_add_snake_case(SentencePieceTrainer)
|
||||
_add_snake_case(SentencePieceNormalizer)
|
||||
set_random_generator_seed = SetRandomGeneratorSeed
|
||||
set_min_log_level = SetMinLogLevel
|
||||
|
||||
|
@ -2987,16 +2987,17 @@ SWIG_Python_NonDynamicSetAttr(PyObject *obj, PyObject *name, PyObject *value) {
|
||||
#define SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText swig_types[3]
|
||||
#define SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece swig_types[4]
|
||||
#define SWIGTYPE_p_sentencepiece__SentenceIterator swig_types[5]
|
||||
#define SWIGTYPE_p_sentencepiece__SentencePieceProcessor swig_types[6]
|
||||
#define SWIGTYPE_p_sentencepiece__SentencePieceTrainer swig_types[7]
|
||||
#define SWIGTYPE_p_std__string swig_types[8]
|
||||
#define SWIGTYPE_p_std__unordered_mapT_std__string_std__string_t swig_types[9]
|
||||
#define SWIGTYPE_p_std__vectorT_absl__string_view_t swig_types[10]
|
||||
#define SWIGTYPE_p_std__vectorT_int_t swig_types[11]
|
||||
#define SWIGTYPE_p_std__vectorT_std__vectorT_absl__string_view_t_t swig_types[12]
|
||||
#define SWIGTYPE_p_std__vectorT_std__vectorT_int_t_t swig_types[13]
|
||||
static swig_type_info *swig_types[15];
|
||||
static swig_module_info swig_module = {swig_types, 14, 0, 0, 0, 0};
|
||||
#define SWIGTYPE_p_sentencepiece__SentencePieceNormalizer swig_types[6]
|
||||
#define SWIGTYPE_p_sentencepiece__SentencePieceProcessor swig_types[7]
|
||||
#define SWIGTYPE_p_sentencepiece__SentencePieceTrainer swig_types[8]
|
||||
#define SWIGTYPE_p_std__string swig_types[9]
|
||||
#define SWIGTYPE_p_std__unordered_mapT_std__string_std__string_t swig_types[10]
|
||||
#define SWIGTYPE_p_std__vectorT_absl__string_view_t swig_types[11]
|
||||
#define SWIGTYPE_p_std__vectorT_int_t swig_types[12]
|
||||
#define SWIGTYPE_p_std__vectorT_std__vectorT_absl__string_view_t_t swig_types[13]
|
||||
#define SWIGTYPE_p_std__vectorT_std__vectorT_int_t_t swig_types[14]
|
||||
static swig_type_info *swig_types[16];
|
||||
static swig_module_info swig_module = {swig_types, 15, 0, 0, 0, 0};
|
||||
#define SWIG_TypeQuery(name) SWIG_TypeQueryModule(&swig_module, &swig_module, name)
|
||||
#define SWIG_MangledTypeQuery(name) SWIG_MangledTypeQueryModule(&swig_module, &swig_module, name)
|
||||
|
||||
@ -4123,6 +4124,27 @@ SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceTrainer__TrainF
|
||||
if (!_status.ok()) throw _status;
|
||||
return model_proto;
|
||||
}
|
||||
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceNormalizer_LoadFromFile(sentencepiece::SentencePieceNormalizer *self,absl::string_view arg){
|
||||
return self->Load(arg);
|
||||
}
|
||||
SWIGINTERN std::string sentencepiece_SentencePieceNormalizer__Normalize(sentencepiece::SentencePieceNormalizer *self,absl::string_view text){
|
||||
std::string result;
|
||||
const auto _status = self->Normalize(text, &result);
|
||||
if (!_status.ok()) throw _status;
|
||||
return result;
|
||||
}
|
||||
SWIGINTERN std::pair< std::string,std::vector< size_t > > sentencepiece_SentencePieceNormalizer__NormalizeWithOffsets(sentencepiece::SentencePieceNormalizer *self,absl::string_view text){
|
||||
std::pair<std::string, std::vector<size_t>> result;
|
||||
const auto _status = self->Normalize(text, &result.first, &result.second);
|
||||
if (!_status.ok()) throw _status;
|
||||
return result;
|
||||
}
|
||||
SWIGINTERN void sentencepiece_SentencePieceNormalizer__SetProtoField(sentencepiece::SentencePieceNormalizer *self,absl::string_view name,bool value){
|
||||
sentencepiece::SentencePieceTrainer::SetProtoField(
|
||||
name,
|
||||
value ? "1" : "0",
|
||||
self->mutable_normalizer_spec()).IgnoreError();
|
||||
}
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@ -8846,6 +8868,419 @@ SWIGINTERN PyObject *SentencePieceTrainer_swigregister(PyObject *SWIGUNUSEDPARM(
|
||||
return SWIG_Py_Void();
|
||||
}
|
||||
|
||||
SWIGINTERN PyObject *_wrap_new_SentencePieceNormalizer(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *result = 0 ;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "new_SentencePieceNormalizer", 0, 0, 0)) SWIG_fail;
|
||||
{
|
||||
try {
|
||||
result = (sentencepiece::SentencePieceNormalizer *)new sentencepiece::SentencePieceNormalizer();
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
resultobj = SWIG_NewPointerObj(SWIG_as_voidptr(result), SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, SWIG_POINTER_NEW | 0 );
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_delete_SentencePieceNormalizer(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[1] ;
|
||||
|
||||
if (!args) SWIG_fail;
|
||||
swig_obj[0] = args;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, SWIG_POINTER_DISOWN | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "delete_SentencePieceNormalizer" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
try {
|
||||
delete arg1;
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
resultobj = SWIG_Py_Void();
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_LoadFromSerializedProto(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
absl::string_view arg2 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
sentencepiece::util::Status result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer_LoadFromSerializedProto", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_LoadFromSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
const PyInputString ustring(swig_obj[1]);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
arg2 = ustring.str();
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = (arg1)->LoadFromSerializedProto(arg2);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
if (!(&result)->ok()) {
|
||||
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
|
||||
}
|
||||
resultobj = SWIG_From_bool((&result)->ok());
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_LoadFromRuleTSV(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
absl::string_view arg2 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
sentencepiece::util::Status result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer_LoadFromRuleTSV", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_LoadFromRuleTSV" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
const PyInputString ustring(swig_obj[1]);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
arg2 = ustring.str();
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = (arg1)->LoadFromRuleTSV(arg2);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
if (!(&result)->ok()) {
|
||||
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
|
||||
}
|
||||
resultobj = SWIG_From_bool((&result)->ok());
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_LoadFromRuleName(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
absl::string_view arg2 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
sentencepiece::util::Status result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer_LoadFromRuleName", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_LoadFromRuleName" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
const PyInputString ustring(swig_obj[1]);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
arg2 = ustring.str();
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = (arg1)->LoadFromRuleName(arg2);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
if (!(&result)->ok()) {
|
||||
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
|
||||
}
|
||||
resultobj = SWIG_From_bool((&result)->ok());
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_serialized_model_proto(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[1] ;
|
||||
std::string result;
|
||||
|
||||
if (!args) SWIG_fail;
|
||||
swig_obj[0] = args;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_serialized_model_proto" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer const *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
try {
|
||||
result = ((sentencepiece::SentencePieceNormalizer const *)arg1)->serialized_model_proto();
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
PyObject *input_type = resultobj;
|
||||
resultobj = MakePyOutputString(result, input_type);
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer_LoadFromFile(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
absl::string_view arg2 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
sentencepiece::util::Status result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer_LoadFromFile", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer_LoadFromFile" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
const PyInputString ustring(swig_obj[1]);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
arg2 = ustring.str();
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = sentencepiece_SentencePieceNormalizer_LoadFromFile(arg1,SWIG_STD_MOVE(arg2));
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
if (!(&result)->ok()) {
|
||||
SWIG_exception(ToSwigError((&result)->code()), (&result)->ToString().c_str());
|
||||
}
|
||||
resultobj = SWIG_From_bool((&result)->ok());
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer__Normalize(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
absl::string_view arg2 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
std::string result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer__Normalize", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer__Normalize" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
const PyInputString ustring(swig_obj[1]);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
arg2 = ustring.str();
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = sentencepiece_SentencePieceNormalizer__Normalize(arg1,SWIG_STD_MOVE(arg2));
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
PyObject *input_type = resultobj;
|
||||
resultobj = MakePyOutputString(result, input_type);
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer__NormalizeWithOffsets(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
absl::string_view arg2 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
std::pair< std::string,std::vector< size_t > > result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer__NormalizeWithOffsets", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer__NormalizeWithOffsets" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
const PyInputString ustring(swig_obj[1]);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
arg2 = ustring.str();
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = sentencepiece_SentencePieceNormalizer__NormalizeWithOffsets(arg1,SWIG_STD_MOVE(arg2));
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
PyObject *input_type = resultobj;
|
||||
PyObject *obj = PyList_New((&result)->second.size());
|
||||
for (size_t i = 0; i < (&result)->second.size(); ++i) {
|
||||
PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>((&result)->second[i])));
|
||||
}
|
||||
resultobj = PyTuple_Pack(2, MakePyOutputString((&result)->first, input_type), obj);
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceNormalizer__SetProtoField(PyObject *self, PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceNormalizer *arg1 = (sentencepiece::SentencePieceNormalizer *) 0 ;
|
||||
absl::string_view arg2 ;
|
||||
bool arg3 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
bool val3 ;
|
||||
int ecode3 = 0 ;
|
||||
PyObject *swig_obj[3] ;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceNormalizer__SetProtoField", 3, 3, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceNormalizer__SetProtoField" "', argument " "1"" of type '" "sentencepiece::SentencePieceNormalizer *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceNormalizer * >(argp1);
|
||||
{
|
||||
const PyInputString ustring(swig_obj[1]);
|
||||
if (!ustring.IsAvalable()) {
|
||||
PyErr_SetString(PyExc_TypeError, "not a string");
|
||||
SWIG_fail;
|
||||
}
|
||||
resultobj = ustring.input_type();
|
||||
arg2 = ustring.str();
|
||||
}
|
||||
ecode3 = SWIG_AsVal_bool(swig_obj[2], &val3);
|
||||
if (!SWIG_IsOK(ecode3)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(ecode3), "in method '" "SentencePieceNormalizer__SetProtoField" "', argument " "3"" of type '" "bool""'");
|
||||
}
|
||||
arg3 = static_cast< bool >(val3);
|
||||
{
|
||||
try {
|
||||
sentencepiece_SentencePieceNormalizer__SetProtoField(arg1,SWIG_STD_MOVE(arg2),arg3);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
resultobj = SWIG_Py_Void();
|
||||
return resultobj;
|
||||
fail:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *SentencePieceNormalizer_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *obj;
|
||||
if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL;
|
||||
SWIG_TypeNewClientData(SWIGTYPE_p_sentencepiece__SentencePieceNormalizer, SWIG_NewClientData(obj));
|
||||
return SWIG_Py_Void();
|
||||
}
|
||||
|
||||
SWIGINTERN PyObject *SentencePieceNormalizer_swiginit(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
return SWIG_Python_InitShadowInstance(args);
|
||||
}
|
||||
|
||||
static PyMethodDef SwigMethods[] = {
|
||||
{ "new_ImmutableSentencePieceText_ImmutableSentencePiece", _wrap_new_ImmutableSentencePieceText_ImmutableSentencePiece, METH_NOARGS, NULL},
|
||||
{ "delete_ImmutableSentencePieceText_ImmutableSentencePiece", _wrap_delete_ImmutableSentencePieceText_ImmutableSentencePiece, METH_O, NULL},
|
||||
@ -8937,6 +9372,18 @@ static PyMethodDef SwigMethods[] = {
|
||||
{ "SentencePieceTrainer__TrainFromMap3", _wrap_SentencePieceTrainer__TrainFromMap3, METH_O, NULL},
|
||||
{ "SentencePieceTrainer__TrainFromMap4", _wrap_SentencePieceTrainer__TrainFromMap4, METH_VARARGS, NULL},
|
||||
{ "SentencePieceTrainer_swigregister", SentencePieceTrainer_swigregister, METH_O, NULL},
|
||||
{ "new_SentencePieceNormalizer", _wrap_new_SentencePieceNormalizer, METH_NOARGS, NULL},
|
||||
{ "delete_SentencePieceNormalizer", _wrap_delete_SentencePieceNormalizer, METH_O, NULL},
|
||||
{ "SentencePieceNormalizer_LoadFromSerializedProto", _wrap_SentencePieceNormalizer_LoadFromSerializedProto, METH_VARARGS, NULL},
|
||||
{ "SentencePieceNormalizer_LoadFromRuleTSV", _wrap_SentencePieceNormalizer_LoadFromRuleTSV, METH_VARARGS, NULL},
|
||||
{ "SentencePieceNormalizer_LoadFromRuleName", _wrap_SentencePieceNormalizer_LoadFromRuleName, METH_VARARGS, NULL},
|
||||
{ "SentencePieceNormalizer_serialized_model_proto", _wrap_SentencePieceNormalizer_serialized_model_proto, METH_O, NULL},
|
||||
{ "SentencePieceNormalizer_LoadFromFile", _wrap_SentencePieceNormalizer_LoadFromFile, METH_VARARGS, NULL},
|
||||
{ "SentencePieceNormalizer__Normalize", _wrap_SentencePieceNormalizer__Normalize, METH_VARARGS, NULL},
|
||||
{ "SentencePieceNormalizer__NormalizeWithOffsets", _wrap_SentencePieceNormalizer__NormalizeWithOffsets, METH_VARARGS, NULL},
|
||||
{ "SentencePieceNormalizer__SetProtoField", _wrap_SentencePieceNormalizer__SetProtoField, METH_VARARGS, NULL},
|
||||
{ "SentencePieceNormalizer_swigregister", SentencePieceNormalizer_swigregister, METH_O, NULL},
|
||||
{ "SentencePieceNormalizer_swiginit", SentencePieceNormalizer_swiginit, METH_VARARGS, NULL},
|
||||
{ NULL, NULL, 0, NULL }
|
||||
};
|
||||
|
||||
@ -8949,6 +9396,7 @@ static swig_type_info _swigt__p_sentencepiece__ImmutableNBestSentencePieceText =
|
||||
static swig_type_info _swigt__p_sentencepiece__ImmutableSentencePieceText = {"_p_sentencepiece__ImmutableSentencePieceText", "sentencepiece::ImmutableSentencePieceText *", 0, 0, (void*)0, 0};
|
||||
static swig_type_info _swigt__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece = {"_p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece", "sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece *", 0, 0, (void*)0, 0};
|
||||
static swig_type_info _swigt__p_sentencepiece__SentenceIterator = {"_p_sentencepiece__SentenceIterator", "sentencepiece::SentenceIterator *", 0, 0, (void*)0, 0};
|
||||
static swig_type_info _swigt__p_sentencepiece__SentencePieceNormalizer = {"_p_sentencepiece__SentencePieceNormalizer", "sentencepiece::SentencePieceNormalizer *", 0, 0, (void*)0, 0};
|
||||
static swig_type_info _swigt__p_sentencepiece__SentencePieceProcessor = {"_p_sentencepiece__SentencePieceProcessor", "sentencepiece::SentencePieceProcessor *", 0, 0, (void*)0, 0};
|
||||
static swig_type_info _swigt__p_sentencepiece__SentencePieceTrainer = {"_p_sentencepiece__SentencePieceTrainer", "sentencepiece::SentencePieceTrainer *", 0, 0, (void*)0, 0};
|
||||
static swig_type_info _swigt__p_std__string = {"_p_std__string", "sentencepiece::util::bytes *|std::string *", 0, 0, (void*)0, 0};
|
||||
@ -8965,6 +9413,7 @@ static swig_type_info *swig_type_initial[] = {
|
||||
&_swigt__p_sentencepiece__ImmutableSentencePieceText,
|
||||
&_swigt__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece,
|
||||
&_swigt__p_sentencepiece__SentenceIterator,
|
||||
&_swigt__p_sentencepiece__SentencePieceNormalizer,
|
||||
&_swigt__p_sentencepiece__SentencePieceProcessor,
|
||||
&_swigt__p_sentencepiece__SentencePieceTrainer,
|
||||
&_swigt__p_std__string,
|
||||
@ -8981,6 +9430,7 @@ static swig_cast_info _swigc__p_sentencepiece__ImmutableNBestSentencePieceText[]
|
||||
static swig_cast_info _swigc__p_sentencepiece__ImmutableSentencePieceText[] = { {&_swigt__p_sentencepiece__ImmutableSentencePieceText, 0, 0, 0},{0, 0, 0, 0}};
|
||||
static swig_cast_info _swigc__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece[] = { {&_swigt__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece, 0, 0, 0},{0, 0, 0, 0}};
|
||||
static swig_cast_info _swigc__p_sentencepiece__SentenceIterator[] = { {&_swigt__p_sentencepiece__SentenceIterator, 0, 0, 0},{0, 0, 0, 0}};
|
||||
static swig_cast_info _swigc__p_sentencepiece__SentencePieceNormalizer[] = { {&_swigt__p_sentencepiece__SentencePieceNormalizer, 0, 0, 0},{0, 0, 0, 0}};
|
||||
static swig_cast_info _swigc__p_sentencepiece__SentencePieceProcessor[] = { {&_swigt__p_sentencepiece__SentencePieceProcessor, 0, 0, 0},{0, 0, 0, 0}};
|
||||
static swig_cast_info _swigc__p_sentencepiece__SentencePieceTrainer[] = { {&_swigt__p_sentencepiece__SentencePieceTrainer, 0, 0, 0},{0, 0, 0, 0}};
|
||||
static swig_cast_info _swigc__p_std__string[] = { {&_swigt__p_std__string, 0, 0, 0},{0, 0, 0, 0}};
|
||||
@ -8997,6 +9447,7 @@ static swig_cast_info *swig_cast_initial[] = {
|
||||
_swigc__p_sentencepiece__ImmutableSentencePieceText,
|
||||
_swigc__p_sentencepiece__ImmutableSentencePieceText_ImmutableSentencePiece,
|
||||
_swigc__p_sentencepiece__SentenceIterator,
|
||||
_swigc__p_sentencepiece__SentencePieceNormalizer,
|
||||
_swigc__p_sentencepiece__SentencePieceProcessor,
|
||||
_swigc__p_sentencepiece__SentencePieceTrainer,
|
||||
_swigc__p_std__string,
|
||||
|
@ -790,6 +790,64 @@ class TestSentencepieceProcessor(unittest.TestCase):
|
||||
self.assertEqual('▁平成', x[1][0])
|
||||
self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0, 3], x[1][1])
|
||||
|
||||
def test_normalizer(self):
  """Exercises SentencePieceNormalizer loaded from a model file.

  Covers single-string and batch normalization, offset reporting, and the
  whitespace-handling flags (add_dummy_prefix / escape_whitespaces /
  remove_extra_whitespaces).
  """
  # NOTE(review): runs of spaces and character widths in the literals below
  # may have been collapsed by the diff rendering — verify against the
  # original test source.
  model_path = os.path.join('test', 'test_model.model')
  sp = spm.SentencePieceNormalizer(model_file=model_path)

  # snake_case and CamelCase entry points must behave identically.
  self.assertEqual('KADOKAWAABC', sp.normalize('KADOKAWAABC'))
  self.assertEqual('KADOKAWAABC', sp.Normalize('KADOKAWAABC'))

  # with_offsets=True additionally returns byte offsets into the input.
  result = sp.Normalize('KADOKAWAABC', with_offsets=True)
  self.assertEqual('KADOKAWAABC', result[0])
  self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], result[1])

  # Batch input: a list in yields a list out, element-wise.
  batch = ['KADOKAWAABC', '㍻']
  expected = ['KADOKAWAABC', '平成']
  self.assertEqual(expected, sp.normalize(batch))
  self.assertEqual(expected, sp.Normalize(batch))

  results = sp.Normalize(batch, with_offsets=True)
  self.assertEqual(len(results), 2)
  self.assertEqual('KADOKAWAABC', results[0][0])
  self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27],
                   results[0][1])
  self.assertEqual('平成', results[1][0])
  self.assertEqual([0, 0, 0, 0, 0, 0, 3], results[1][1])

  # Whitespace-handling flags: each tuple is (constructor flags, input,
  # expected normalized output).
  cases = [
      (dict(add_dummy_prefix=True, escape_whitespaces=True,
            remove_extra_whitespaces=False),
       'hello world', '▁hello▁▁world'),
      (dict(add_dummy_prefix=True, escape_whitespaces=True,
            remove_extra_whitespaces=True),
       ' hello world ', '▁hello▁world'),
      (dict(add_dummy_prefix=False, escape_whitespaces=False,
            remove_extra_whitespaces=True),
       ' hello world ', 'hello world'),
  ]
  for flags, text, want in cases:
    sp = spm.SentencePieceNormalizer(model_file=model_path, **flags)
    self.assertEqual(want, sp.normalize(text))
|
||||
|
||||
def test_normalizer_rule(self):
  """Builds normalizers from built-in rule names, without a model file."""
  # 'identity' leaves the text untouched.
  identity = spm.SentencePieceNormalizer(rule_name='identity')
  self.assertEqual('ABC', identity.Normalize('ABC'))

  # 'nfkc_cf' applies NFKC normalization plus case folding.
  nfkc_cf = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
  self.assertEqual('abc', nfkc_cf.Normalize('ABC'))
|
||||
|
||||
|
||||
def suite():
|
||||
suite = unittest.TestSuite()
|
||||
|
@ -295,4 +295,75 @@ SentencePieceTrainer::GetPretokenizerForTraining() {
|
||||
return g_pretokenizer;
|
||||
}
|
||||
|
||||
// Constructor/destructor are defined out-of-line so the std::unique_ptr
// members over forward-declared types (ModelProto, normalizer::Normalizer)
// are destroyed here, where the full type definitions are visible.
SentencePieceNormalizer::SentencePieceNormalizer() = default;
SentencePieceNormalizer::~SentencePieceNormalizer() = default;
|
||||
|
||||
// Takes ownership of `model_proto` and rebuilds the internal normalizer
// from its normalizer_spec.  Returns the new normalizer's status so that
// an invalid spec is reported to the caller.
util::Status SentencePieceNormalizer::Load(
    std::unique_ptr<ModelProto> model_proto) {
  model_proto_ = std::move(model_proto);
  const auto &spec = model_proto_->normalizer_spec();
  normalizer_ = std::make_unique<normalizer::Normalizer>(spec);
  CHECK_OR_RETURN(normalizer_);
  return normalizer_->status();
}
|
||||
|
||||
// Reads a serialized model from `filename` and loads it.
util::Status SentencePieceNormalizer::Load(absl::string_view filename) {
  auto proto = std::make_unique<ModelProto>();
  RETURN_IF_ERROR(io::LoadModelProto(filename, proto.get()));
  return Load(std::move(proto));
}
|
||||
|
||||
// Parses `serialized` as a wire-format ModelProto and loads it.
util::Status SentencePieceNormalizer::LoadFromSerializedProto(
    absl::string_view serialized) {
  auto proto = std::make_unique<ModelProto>();
  CHECK_OR_RETURN(proto->ParseFromArray(serialized.data(), serialized.size()));
  return Load(std::move(proto));
}
|
||||
|
||||
// Builds a fresh model whose normalizer spec is compiled from the
// normalization-rule TSV file `filename`, then loads it.
util::Status SentencePieceNormalizer::LoadFromRuleTSV(
    absl::string_view filename) {
  auto proto = std::make_unique<ModelProto>();
  NormalizerSpec *spec = proto->mutable_normalizer_spec();
  spec->set_normalization_rule_tsv(std::string(filename));
  // PopulateNormalizerSpec compiles the TSV into the precompiled charsmap.
  RETURN_IF_ERROR(SentencePieceTrainer::PopulateNormalizerSpec(spec));
  return Load(std::move(proto));
}
|
||||
|
||||
// Builds a fresh model from a built-in normalization rule name
// (e.g. "nfkc", "identity"), then loads it.
util::Status SentencePieceNormalizer::LoadFromRuleName(absl::string_view name) {
  auto proto = std::make_unique<ModelProto>();
  NormalizerSpec *spec = proto->mutable_normalizer_spec();
  spec->set_name(std::string(name));
  // PopulateNormalizerSpec resolves the name to its precompiled charsmap.
  RETURN_IF_ERROR(SentencePieceTrainer::PopulateNormalizerSpec(spec));
  return Load(std::move(proto));
}
|
||||
|
||||
// Normalizes `input` into `*normalized`, discarding offset information.
// Delegates to the three-argument overload so the `normalizer_` precondition
// check and the call into the normalizer live in exactly one place.
util::Status SentencePieceNormalizer::Normalize(absl::string_view input,
                                                std::string *normalized) const {
  std::vector<size_t> norm_to_orig;  // computed by the core API, discarded here
  return Normalize(input, normalized, &norm_to_orig);
}
|
||||
|
||||
// Normalizes `input` into `*normalized` and fills `*norm_to_orig` with the
// mapping from normalized bytes back to byte offsets in `input`.
util::Status SentencePieceNormalizer::Normalize(
    absl::string_view input, std::string *normalized,
    std::vector<size_t> *norm_to_orig) const {
  // A normalizer must have been installed by one of the Load* methods first.
  CHECK_OR_RETURN(normalizer_);
  const auto status = normalizer_->Normalize(input, normalized, norm_to_orig);
  return status;
}
|
||||
|
||||
// Convenience overload: returns the normalized text directly.
// Best-effort — on failure the error is ignored and the result stays empty.
std::string SentencePieceNormalizer::Normalize(absl::string_view input) const {
  std::string result;
  Normalize(input, &result).IgnoreError();
  return result;
}
|
||||
|
||||
// Mutable access to the loaded model's normalizer spec, or nullptr when no
// model has been loaded yet.
NormalizerSpec *SentencePieceNormalizer::mutable_normalizer_spec() const {
  if (!model_proto_) return nullptr;
  return model_proto_->mutable_normalizer_spec();
}
|
||||
|
||||
std::string SentencePieceNormalizer::serialized_model_proto() const {
|
||||
return model_proto_ ? model_proto_->SerializeAsString() : "";
|
||||
}
|
||||
|
||||
} // namespace sentencepiece
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "sentencepiece_processor.h"
|
||||
|
||||
@ -24,11 +25,16 @@ namespace sentencepiece {
|
||||
|
||||
class TrainerSpec;
|
||||
class NormalizerSpec;
|
||||
class ModelProto;
|
||||
|
||||
namespace pretokenizer {
|
||||
class PretokenizerForTrainingInterface;
|
||||
} // namespace pretokenizer
|
||||
|
||||
namespace normalizer {
|
||||
class Normalizer;
|
||||
} // namespace normalizer
|
||||
|
||||
// Iterator over the training sentences.
|
||||
// Training sentences are loaded sequentially as follows:
|
||||
//
|
||||
@ -158,6 +164,39 @@ class SentencePieceTrainer {
|
||||
~SentencePieceTrainer() {}
|
||||
};
|
||||
|
||||
class SentencePieceNormalizer {
|
||||
public:
|
||||
SentencePieceNormalizer();
|
||||
virtual ~SentencePieceNormalizer();
|
||||
|
||||
virtual util::Status Load(std::unique_ptr<ModelProto> model_proto);
|
||||
|
||||
virtual util::Status Load(absl::string_view filename);
|
||||
|
||||
virtual util::Status LoadFromSerializedProto(absl::string_view serialized);
|
||||
|
||||
virtual util::Status LoadFromRuleTSV(absl::string_view filename);
|
||||
|
||||
virtual util::Status LoadFromRuleName(absl::string_view name);
|
||||
|
||||
virtual util::Status Normalize(absl::string_view input,
|
||||
std::string *normalized) const;
|
||||
|
||||
virtual util::Status Normalize(absl::string_view input,
|
||||
std::string *normalized,
|
||||
std::vector<size_t> *norm_to_orig) const;
|
||||
|
||||
virtual std::string Normalize(absl::string_view input) const;
|
||||
|
||||
virtual NormalizerSpec *mutable_normalizer_spec() const;
|
||||
|
||||
virtual std::string serialized_model_proto() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<normalizer::Normalizer> normalizer_;
|
||||
std::unique_ptr<ModelProto> model_proto_;
|
||||
};
|
||||
|
||||
} // namespace sentencepiece
|
||||
|
||||
#endif // SENTENCEPIECE_TRAINER_H_
|
||||
|
@ -364,5 +364,94 @@ TEST(SentencePieceTrainerTest, PopulateModelTypeFromStringTest) {
|
||||
SentencePieceTrainer::PopulateModelTypeFromString("", &spec).ok());
|
||||
}
|
||||
|
||||
TEST(SentencePieceTrainerTest, NormalizationTest) {
  // NOTE(review): runs of spaces and character widths in the string literals
  // below may have been mangled by the diff rendering — verify against the
  // original test source.
  const auto model_prefix =
      util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
  const auto model_file = absl::StrCat(model_prefix, ".model");

  // Train a tiny model whose normalizer spec the test can exercise.
  TrainerSpec trainer_spec;
  trainer_spec.add_input(
      util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData));
  trainer_spec.set_model_prefix(model_prefix);
  trainer_spec.set_vocab_size(1000);
  ASSERT_TRUE(SentencePieceTrainer::Train(trainer_spec).ok());

  // Shared expectation for the offset-reporting checks below.
  const std::vector<size_t> expected_offsets = {0,  0,  0,  0,  3,  6,
                                                9,  12, 15, 18, 21, 24,
                                                24, 24, 27, 28, 29, 30};

  {
    // The full processor and the stand-alone normalizer must agree.
    SentencePieceProcessor processor;
    EXPECT_OK(processor.Load(model_file));
    EXPECT_EQ(processor.Normalize("KADOKAWA ABC "), "▁KADOKAWA▁ABC");

    std::string normalized;
    std::vector<size_t> offsets;
    EXPECT_OK(processor.Normalize("KADOKAWA ABC ", &normalized, &offsets));
    EXPECT_EQ(normalized, "▁KADOKAWA▁ABC");
    EXPECT_EQ(offsets, expected_offsets);
  }

  {
    SentencePieceNormalizer normalizer;
    EXPECT_OK(normalizer.Load(model_file));
    EXPECT_EQ(normalizer.Normalize("KADOKAWA ABC "), "▁KADOKAWA▁ABC");

    std::string normalized;
    std::vector<size_t> offsets;
    EXPECT_OK(normalizer.Normalize("KADOKAWA ABC ", &normalized, &offsets));
    EXPECT_EQ(normalized, "▁KADOKAWA▁ABC");
    EXPECT_EQ(offsets, expected_offsets);
  }

  // Disables dummy-prefix insertion, whitespace escaping and extra-whitespace
  // removal so that only character-level normalization rules remain active.
  auto set_normalization_only = [](SentencePieceNormalizer *normalizer) {
    for (const char *field : {"add_dummy_prefix", "escape_whitespaces",
                              "remove_extra_whitespaces"}) {
      SentencePieceTrainer::SetProtoField(field, "false",
                                          normalizer->mutable_normalizer_spec());
    }
  };

  {
    SentencePieceNormalizer normalizer;
    EXPECT_OK(normalizer.Load(model_file));
    set_normalization_only(&normalizer);
    EXPECT_EQ(normalizer.Normalize("KADOKAWA ABC "), "KADOKAWA ABC ");
  }

  {
    SentencePieceNormalizer normalizer;
    EXPECT_OK(normalizer.LoadFromRuleTSV(
        util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), "nfkc_cf.tsv")));
    set_normalization_only(&normalizer);
    EXPECT_EQ(normalizer.Normalize("ABCD"), "abcd");
  }

  {
    // A nonexistent TSV must fail to load.
    SentencePieceNormalizer normalizer;
    EXPECT_FALSE(normalizer.LoadFromRuleTSV("__unknown__").ok());
  }

  {
    SentencePieceNormalizer normalizer;
    EXPECT_OK(normalizer.LoadFromRuleName("nfkc_cf"));
    set_normalization_only(&normalizer);
    EXPECT_EQ(normalizer.Normalize("ABCD"), "abcd");
  }

  {
    SentencePieceNormalizer normalizer;
    EXPECT_OK(normalizer.LoadFromRuleName("identity"));
    set_normalization_only(&normalizer);
    EXPECT_EQ(normalizer.Normalize("ABCD"), "ABCD");
  }

  {
    // An unknown rule name must fail to load.
    SentencePieceNormalizer normalizer;
    EXPECT_FALSE(normalizer.LoadFromRuleName("__unknown__").ok());
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace sentencepiece
|
||||
|
Loading…
Reference in New Issue
Block a user