Added Normalization API

Taku Kudo 2024-01-04 09:04:20 +00:00
parent e7b5260e4a
commit 06eee09847
6 changed files with 219 additions and 4 deletions

View File

@@ -387,6 +387,12 @@ class SentencePieceProcessor(object):
    def _SampleEncodeAndScoreAsImmutableProto(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece):
        return _sentencepiece.SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto(self, text, num_samples, alpha, wor, include_best, add_bos, add_eos, reverse, emit_unk_piece)

    def _Normalize(self, text):
        return _sentencepiece.SentencePieceProcessor__Normalize(self, text)

    def _NormalizeWithOffsets(self, text):
        return _sentencepiece.SentencePieceProcessor__NormalizeWithOffsets(self, text)

    def _CalculateEntropy(self, text, alpha):
        return _sentencepiece.SentencePieceProcessor__CalculateEntropy(self, text, alpha)
@@ -859,6 +865,17 @@ class SentencePieceProcessor(object):
    return self._CalculateEntropy(input, alpha)

  def Normalize(self, input, with_offsets=None):
    def _normalize(text):
      if with_offsets:
        return self._NormalizeWithOffsets(text)
      return self._Normalize(text)

    if type(input) is list:
      return [_normalize(x) for x in input]
    return _normalize(input)

  def piece_size(self):
    return self.GetPieceSize()
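
Taken together, the wrapper above gives the Python API a single entry point: Normalize accepts one string or a list of strings, and with_offsets=True switches to the offset-returning variant. A minimal usage sketch, assuming a trained model at an illustrative path (expected values taken from the tests later in this commit):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file='test/test_model.model')

# Single string: returns the normalized string.
sp.Normalize('ＫＡＤＯＫＡＷＡABC')                      # '▁KADOKAWAABC'

# with_offsets=True: returns a (normalized, offsets) tuple instead.
text, offsets = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)

# List input: each element is normalized independently.
sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])              # ['▁KADOKAWAABC', '▁平成']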

View File

@@ -347,6 +347,9 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
%ignore sentencepiece::SentencePieceProcessor::DecodePiecesAsImmutableProto;
%ignore sentencepiece::SentencePieceProcessor::DecodeIdsAsImmutableProto;
%ignore sentencepiece::SentencePieceProcessor::Normalize;
%ignore sentencepiece::SentencePieceProcessor::NormalizeWithOffsets;
%ignore sentencepiece::SentencePieceProcessor::model_proto;
%ignore sentencepiece::SentencePieceProcessor::Load;
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
@@ -648,6 +651,16 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
    return proto;
  }

  // Normalize
  std::string _Normalize(absl::string_view text) {
    return $self->Normalize(text);
  }

  std::pair<std::string, std::vector<size_t>> _NormalizeWithOffsets(absl::string_view text) {
    std::pair<std::string, std::vector<size_t>> result;
    $self->Normalize(text, &result.first, &result.second).IgnoreError();
    return result;
  }

  // Calculate Entropy
  float _CalculateEntropy(absl::string_view text, float alpha) {
@@ -1020,12 +1033,12 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
  def SampleEncodeAndScoreAsSerializedProto(self, input, num_samples=None, alpha=None, **kwargs):
    return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
                                     out_type='serialized_proto', **kwargs)

  def SampleEncodeAndScoreAsImmutableProto(self, input, num_samples=None, alpha=None, **kwargs):
    return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
                                     out_type='immutable_proto', **kwargs)

  def Decode(self, input, out_type=str, num_threads=None):
    """Decode processed id or token sequences.
@@ -1140,6 +1153,17 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
    return self._CalculateEntropy(input, alpha)

  def Normalize(self, input, with_offsets=None):
    def _normalize(text):
      if with_offsets:
        return self._NormalizeWithOffsets(text)
      return self._Normalize(text)

    if type(input) is list:
      return [_normalize(x) for x in input]
    return _normalize(input)

  def piece_size(self):
    return self.GetPieceSize()
@@ -1315,7 +1339,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
    def __init__(self, proto):
      self.proto = proto
      self.len = self.proto._pieces_size()

    def __len__(self):
      return self.len
@@ -1383,7 +1407,7 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
  @property
  def nbests(self):
    return ImmutableNBestSentencePieceText.ImmutableSentencePieceTextIterator(self)

  def __eq__(self, other):
    return self.SerializeAsString() == other.SerializeAsString()
@@ -1654,6 +1678,15 @@ inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
    }
  }

%typemap(out) std::pair<std::string, std::vector<size_t>> {
  PyObject *input_type = resultobj;
  PyObject *obj = PyList_New($1.second.size());
  for (size_t i = 0; i < $1.second.size(); ++i) {
    PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>($1.second[i])));
  }
  $result = PyTuple_Pack(2, MakePyOutputString($1.first, input_type), obj);
}
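
This %typemap(out) is the piece that converts the C++ std::pair into the Python return value: a 2-tuple of the normalized string (str or bytes, mirroring the input type via MakePyOutputString) and a list of integer byte offsets. Roughly, from the Python side (reusing the sp processor from the earlier sketch):

# Illustrative shape of the value built by the typemap above.
normalized, offsets = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)
assert isinstance(normalized, str)             # bytes in, bytes out; str in, str out
assert all(isinstance(o, int) for o in offsets)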
%typemap(in) sentencepiece::SentenceIterator * {
  sentencepiece::SentenceIterator *out = nullptr;
  if (PyIter_Check($input)) {

View File

@@ -4004,6 +4004,14 @@ SWIGINTERN sentencepiece::ImmutableNBestSentencePieceText sentencepiece_Sentence
  proto.ConvertToUnicodeSpans();
  return proto;
}

SWIGINTERN std::string sentencepiece_SentencePieceProcessor__Normalize(sentencepiece::SentencePieceProcessor *self,absl::string_view text){
  return self->Normalize(text);
}

SWIGINTERN std::pair< std::string,std::vector< size_t > > sentencepiece_SentencePieceProcessor__NormalizeWithOffsets(sentencepiece::SentencePieceProcessor *self,absl::string_view text){
  std::pair<std::string, std::vector<size_t>> result;
  self->Normalize(text, &result.first, &result.second).IgnoreError();
  return result;
}

SWIGINTERN float sentencepiece_SentencePieceProcessor__CalculateEntropy(sentencepiece::SentencePieceProcessor *self,absl::string_view text,float alpha){
  return self->CalculateEntropy(text, alpha);
}
@@ -8261,6 +8269,96 @@ fail:
}


SWIGINTERN PyObject *_wrap_SentencePieceProcessor__Normalize(PyObject *self, PyObject *args) {
  PyObject *resultobj = 0;
  sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
  absl::string_view arg2 ;
  void *argp1 = 0 ;
  int res1 = 0 ;
  PyObject *swig_obj[2] ;
  std::string result;

  if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__Normalize", 2, 2, swig_obj)) SWIG_fail;
  res1 = SWIG_ConvertPtr(swig_obj[0], &argp1, SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0);
  if (!SWIG_IsOK(res1)) {
    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__Normalize" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
  }
  arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
  {
    const PyInputString ustring(swig_obj[1]);
    if (!ustring.IsAvalable()) {
      PyErr_SetString(PyExc_TypeError, "not a string");
      SWIG_fail;
    }
    resultobj = ustring.input_type();
    arg2 = ustring.str();
  }
  {
    try {
      result = sentencepiece_SentencePieceProcessor__Normalize(arg1, SWIG_STD_MOVE(arg2));
      ReleaseResultObject(resultobj);
    }
    catch (const sentencepiece::util::Status &status) {
      SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
    }
  }
  {
    PyObject *input_type = resultobj;
    resultobj = MakePyOutputString(result, input_type);
  }
  return resultobj;
fail:
  return NULL;
}
SWIGINTERN PyObject *_wrap_SentencePieceProcessor__NormalizeWithOffsets(PyObject *self, PyObject *args) {
  PyObject *resultobj = 0;
  sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
  absl::string_view arg2 ;
  void *argp1 = 0 ;
  int res1 = 0 ;
  PyObject *swig_obj[2] ;
  std::pair< std::string,std::vector< size_t > > result;

  if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor__NormalizeWithOffsets", 2, 2, swig_obj)) SWIG_fail;
  res1 = SWIG_ConvertPtr(swig_obj[0], &argp1, SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0);
  if (!SWIG_IsOK(res1)) {
    SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor__NormalizeWithOffsets" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor *""'");
  }
  arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
  {
    const PyInputString ustring(swig_obj[1]);
    if (!ustring.IsAvalable()) {
      PyErr_SetString(PyExc_TypeError, "not a string");
      SWIG_fail;
    }
    resultobj = ustring.input_type();
    arg2 = ustring.str();
  }
  {
    try {
      result = sentencepiece_SentencePieceProcessor__NormalizeWithOffsets(arg1, SWIG_STD_MOVE(arg2));
      ReleaseResultObject(resultobj);
    }
    catch (const sentencepiece::util::Status &status) {
      SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
    }
  }
  {
    PyObject *input_type = resultobj;
    PyObject *obj = PyList_New((&result)->second.size());
    for (size_t i = 0; i < (&result)->second.size(); ++i) {
      PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>((&result)->second[i])));
    }
    resultobj = PyTuple_Pack(2, MakePyOutputString((&result)->first, input_type), obj);
  }
  return resultobj;
fail:
  return NULL;
}
SWIGINTERN PyObject *_wrap_SentencePieceProcessor__CalculateEntropy(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
@@ -8825,6 +8923,8 @@ static PyMethodDef SwigMethods[] = {
  { "SentencePieceProcessor__SampleEncodeAndScoreAsPieces", _wrap_SentencePieceProcessor__SampleEncodeAndScoreAsPieces, METH_VARARGS, NULL},
  { "SentencePieceProcessor__SampleEncodeAndScoreAsSerializedProto", _wrap_SentencePieceProcessor__SampleEncodeAndScoreAsSerializedProto, METH_VARARGS, NULL},
  { "SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto", _wrap_SentencePieceProcessor__SampleEncodeAndScoreAsImmutableProto, METH_VARARGS, NULL},
  { "SentencePieceProcessor__Normalize", _wrap_SentencePieceProcessor__Normalize, METH_VARARGS, NULL},
  { "SentencePieceProcessor__NormalizeWithOffsets", _wrap_SentencePieceProcessor__NormalizeWithOffsets, METH_VARARGS, NULL},
  { "SentencePieceProcessor__CalculateEntropy", _wrap_SentencePieceProcessor__CalculateEntropy, METH_VARARGS, NULL},
  { "SentencePieceProcessor__CalculateEntropyBatch", _wrap_SentencePieceProcessor__CalculateEntropyBatch, METH_VARARGS, NULL},
  { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},

View File

@@ -760,6 +760,36 @@ class TestSentencepieceProcessor(unittest.TestCase):
    spm.set_random_generator_seed(1)
    spm.set_min_log_level(3)

  def test_normalize(self):
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )

    self.assertEqual('▁KADOKAWAABC', sp.normalize('ＫＡＤＯＫＡＷＡABC'))
    self.assertEqual('▁KADOKAWAABC', sp.Normalize('ＫＡＤＯＫＡＷＡABC'))

    x = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)
    self.assertEqual('▁KADOKAWAABC', x[0])
    self.assertEqual(
        [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1]
    )

    self.assertEqual(
        ['▁KADOKAWAABC', '▁平成'], sp.normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
    )
    self.assertEqual(
        ['▁KADOKAWAABC', '▁平成'], sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
    )

    x = sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'], with_offsets=True)
    self.assertEqual(len(x), 2)
    self.assertEqual('▁KADOKAWAABC', x[0][0])
    self.assertEqual(
        [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1]
    )
    self.assertEqual('▁平成', x[1][0])
    self.assertEqual([0, 0, 0, 0, 0, 0, 0, 0, 0, 3], x[1][1])
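
The expected arrays pin down the offset semantics: entry i is the UTF-8 byte offset in the original input that produced byte i of the normalized text, and in these tests the final entry equals the input's total byte length ('ＫＡＤＯＫＡＷＡABC' is 27 bytes; '㍻' is 3, and NFKC expands it to '平成'). A hedged sketch of putting the offsets to work, reusing the sp processor from the earlier sketch (names are illustrative):

# Map a normalized byte position back to the original input.
normalized, offsets = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)
orig = 'ＫＡＤＯＫＡＷＡABC'.encode('utf-8')

# Byte 4 of '▁KADOKAWAABC' is the first 'A'; offsets[4] == 3 and
# offsets[5] == 6 bracket the full-width 'Ａ' (3 bytes) it came from.
print(orig[offsets[4]:offsets[5]].decode('utf-8'))   # Ａ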
def suite():
  suite = unittest.TestSuite()

View File

@@ -931,6 +931,26 @@ util::Status SentencePieceProcessor::Decode(const std::vector<int> &ids,
    return value;                               \
  }

util::Status SentencePieceProcessor::Normalize(absl::string_view input,
                                               std::string *normalized) const {
  std::vector<size_t> norm_to_orig;
  CHECK_OR_RETURN(normalizer_);
  return normalizer_->Normalize(input, normalized, &norm_to_orig);
}

util::Status SentencePieceProcessor::Normalize(
    absl::string_view input, std::string *normalized,
    std::vector<size_t> *norm_to_orig) const {
  CHECK_OR_RETURN(normalizer_);
  return normalizer_->Normalize(input, normalized, norm_to_orig);
}

std::string SentencePieceProcessor::Normalize(absl::string_view input) const {
  std::string normalized;
  Normalize(input, &normalized).IgnoreError();
  return normalized;
}

int SentencePieceProcessor::GetPieceSize() const {
  CHECK_STATUS_OR_RETURN_DEFAULT(0);
  return model_->GetPieceSize();
View File

@@ -614,6 +614,21 @@ class SentencePieceProcessor {
#undef DEFINE_SPP_SERIALIZED_PROTO_IMPL
#undef DEFINE_SPP_IMMUTABLE_PROTO_IMPL

  //////////////////////////////////////////////////////////////
  // Normalization methods.

  // Normalizes `input`.
  virtual util::Status Normalize(absl::string_view input,
                                 std::string *normalized) const;

  // Normalizes `input` and stores in `norm_to_orig` the UTF-8 byte
  // offsets that map the normalized string back to the original input.
  virtual util::Status Normalize(absl::string_view input,
                                 std::string *normalized,
                                 std::vector<size_t> *norm_to_orig) const;

  virtual std::string Normalize(absl::string_view input) const;
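
These declarations add three public overloads; through the SWIG glue earlier in this commit, Python reaches them via two private helpers. A sketch of the correspondence (the two-argument Status overload is not wrapped on its own; the std::string convenience form calls it and drops the status):

# C++ overload                                    -> Python helper
# std::string Normalize(string_view)              -> sp._Normalize(text)
# Status Normalize(string_view, &norm, &offsets)  -> sp._NormalizeWithOffsets(text)
s = sp._Normalize('ＫＡＤＯＫＡＷＡABC')
t, offsets = sp._NormalizeWithOffsets('ＫＡＤＯＫＡＷＡABC')
assert s == t == '▁KADOKAWAABC'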
  //////////////////////////////////////////////////////////////
  // Vocabulary management methods.
  //