mirror of
https://github.com/google/sentencepiece.git
synced 2024-08-16 14:21:00 +03:00
validate the range of piece in Python module
This commit is contained in:
parent
d8c4b04056
commit
910f804f72
1
.gitignore
vendored
1
.gitignore
vendored
@ -71,3 +71,4 @@ cmake_install.cmake
|
||||
libsentencepiece.so*
|
||||
libsentencepiece_train.so*
|
||||
python/bundled
|
||||
_sentencepiece.*.so
|
||||
|
@ -116,9 +116,6 @@ class SentencePieceProcessor(object):
|
||||
def DecodePieces(self, pieces):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodePieces(self, pieces)
|
||||
|
||||
def DecodeIds(self, ids):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodeIds(self, ids)
|
||||
|
||||
def EncodeAsSerializedProto(self, input):
|
||||
return _sentencepiece.SentencePieceProcessor_EncodeAsSerializedProto(self, input)
|
||||
|
||||
@ -131,9 +128,6 @@ class SentencePieceProcessor(object):
|
||||
def DecodePiecesAsSerializedProto(self, pieces):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodePiecesAsSerializedProto(self, pieces)
|
||||
|
||||
def DecodeIdsAsSerializedProto(self, ids):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProto(self, ids)
|
||||
|
||||
def GetPieceSize(self):
|
||||
return _sentencepiece.SentencePieceProcessor_GetPieceSize(self)
|
||||
|
||||
@ -176,6 +170,12 @@ class SentencePieceProcessor(object):
|
||||
def LoadFromFile(self, arg):
|
||||
return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
|
||||
|
||||
def DecodeIdsWithCheck(self, ids):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck(self, ids)
|
||||
|
||||
def DecodeIdsAsSerializedProtoWithCheck(self, ids):
|
||||
return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(self, ids)
|
||||
|
||||
def Init(self,
|
||||
model_file=None,
|
||||
model_proto=None,
|
||||
@ -242,8 +242,8 @@ class SentencePieceProcessor(object):
|
||||
nbest_size < 0: assuming that nbest_size is infinite and samples
|
||||
from the all hypothesis (lattice) using
|
||||
forward-filtering-and-backward-sampling algorithm.
|
||||
alpha: Soothing parameter for unigram sampling, and dropout probability of
|
||||
merge operations for BPE-dropout.
|
||||
alpha: Soothing parameter for unigram sampling, and merge probability for
|
||||
BPE-dropout (probablity 'p' in BPE-dropout paper).
|
||||
"""
|
||||
|
||||
if out_type is None:
|
||||
@ -262,12 +262,12 @@ class SentencePieceProcessor(object):
|
||||
alpha = self._alpha
|
||||
|
||||
if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
|
||||
nbest_size == 1 or alpha is None or
|
||||
alpha <= 0.0 or alpha > 1.0):
|
||||
nbest_size == 1 or alpha is None):
|
||||
raise RuntimeError(
|
||||
'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
|
||||
'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and '
|
||||
'samples from all candidates on the lattice instead of nbest segmentations. '
|
||||
'and "alpha". "nbest_size" is enabled only on unigram mode ignored in BPE-dropout. '
|
||||
'when "nbest_size = -1" , this method samples from all candidates on the lattice '
|
||||
'instead of nbest segmentations.'
|
||||
)
|
||||
|
||||
def _encode(text):
|
||||
@ -310,7 +310,7 @@ class SentencePieceProcessor(object):
|
||||
if not input:
|
||||
return self.DecodeIds([])
|
||||
elif type(input) is int:
|
||||
return self.DecodeIds([input])
|
||||
return self.DecodeIdsWithCheck([input])
|
||||
elif type(input) is str:
|
||||
return self.DecodePieces([input])
|
||||
|
||||
@ -318,7 +318,7 @@ class SentencePieceProcessor(object):
|
||||
if not input:
|
||||
return self.DecodeIds([])
|
||||
if type(input[0]) is int:
|
||||
return self.DecodeIds(input)
|
||||
return self.DecodeIdsWithCheck(input)
|
||||
return self.DecodePieces(input)
|
||||
|
||||
if type(input[0]) is list:
|
||||
@ -486,12 +486,16 @@ def _add_snake_case(classname):
|
||||
def _batchnize(classname, name):
|
||||
"""Enables batch request for the method classname.name."""
|
||||
func = getattr(classname, name, None)
|
||||
def _func(v, n):
|
||||
if type(n) is int and (n < 0 or n >= v.piece_size()):
|
||||
raise IndexError('piece id is out of range.')
|
||||
return func(v, n)
|
||||
|
||||
def _batched_func(self, arg):
|
||||
if type(arg) is list:
|
||||
return [func(self, n) for n in arg]
|
||||
return [_func(self, n) for n in arg]
|
||||
else:
|
||||
return func(self, arg)
|
||||
return _func(self, arg)
|
||||
|
||||
setattr(classname, name, _batched_func)
|
||||
|
||||
@ -501,6 +505,8 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
|
||||
|
||||
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
|
||||
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
|
||||
SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
|
||||
SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck
|
||||
|
||||
for m in [
|
||||
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
|
||||
|
Binary file not shown.
@ -176,6 +176,8 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
|
||||
%ignore sentencepiece::SentencePieceProcessor::SampleEncode;
|
||||
%ignore sentencepiece::SentencePieceProcessor::NBestEncode;
|
||||
%ignore sentencepiece::SentencePieceProcessor::Decode;
|
||||
%ignore sentencepiece::SentencePieceProcessor::DecodeIds;
|
||||
%ignore sentencepiece::SentencePieceProcessor::DecodeIdsAsSerializedProto;
|
||||
%ignore sentencepiece::SentencePieceProcessor::model_proto;
|
||||
%ignore sentencepiece::SentencePieceProcessor::Load;
|
||||
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
|
||||
@ -196,6 +198,26 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
|
||||
return $self->Load(arg);
|
||||
}
|
||||
|
||||
std::string DecodeIdsWithCheck(
|
||||
const std::vector<int> &ids) const {
|
||||
for (int id : ids)
|
||||
if (id < 0 || id >= $self->GetPieceSize())
|
||||
throw sentencepiece::util::Status(
|
||||
sentencepiece::util::StatusCode::kOutOfRange,
|
||||
"piece id is out of range.");
|
||||
return $self->DecodeIds(ids);
|
||||
}
|
||||
|
||||
util::bytes DecodeIdsAsSerializedProtoWithCheck(
|
||||
const std::vector<int> &ids) const {
|
||||
for (int id : ids)
|
||||
if (id < 0 || id >= $self->GetPieceSize())
|
||||
throw sentencepiece::util::Status(
|
||||
sentencepiece::util::StatusCode::kOutOfRange,
|
||||
"piece id is out of range.");
|
||||
return $self->DecodeIdsAsSerializedProto(ids);
|
||||
}
|
||||
|
||||
%pythoncode {
|
||||
def Init(self,
|
||||
model_file=None,
|
||||
@ -264,7 +286,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
|
||||
from the all hypothesis (lattice) using
|
||||
forward-filtering-and-backward-sampling algorithm.
|
||||
alpha: Soothing parameter for unigram sampling, and merge probability for
|
||||
BPE-dropout.
|
||||
BPE-dropout (probablity 'p' in BPE-dropout paper).
|
||||
"""
|
||||
|
||||
if out_type is None:
|
||||
@ -283,12 +305,12 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
|
||||
alpha = self._alpha
|
||||
|
||||
if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
|
||||
nbest_size == 1 or alpha is None or
|
||||
alpha <= 0.0 or alpha > 1.0):
|
||||
nbest_size == 1 or alpha is None):
|
||||
raise RuntimeError(
|
||||
'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
|
||||
'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and '
|
||||
'samples from all candidates on the lattice instead of nbest segmentations. '
|
||||
'and "alpha". "nbest_size" is enabled only on unigram mode ignored in BPE-dropout. '
|
||||
'when "nbest_size = -1" , this method samples from all candidates on the lattice '
|
||||
'instead of nbest segmentations.'
|
||||
)
|
||||
|
||||
def _encode(text):
|
||||
@ -331,7 +353,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
|
||||
if not input:
|
||||
return self.DecodeIds([])
|
||||
elif type(input) is int:
|
||||
return self.DecodeIds([input])
|
||||
return self.DecodeIdsWithCheck([input])
|
||||
elif type(input) is str:
|
||||
return self.DecodePieces([input])
|
||||
|
||||
@ -339,7 +361,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator {
|
||||
if not input:
|
||||
return self.DecodeIds([])
|
||||
if type(input[0]) is int:
|
||||
return self.DecodeIds(input)
|
||||
return self.DecodeIdsWithCheck(input)
|
||||
return self.DecodePieces(input)
|
||||
|
||||
if type(input[0]) is list:
|
||||
@ -688,12 +710,16 @@ def _add_snake_case(classname):
|
||||
def _batchnize(classname, name):
|
||||
"""Enables batch request for the method classname.name."""
|
||||
func = getattr(classname, name, None)
|
||||
def _func(v, n):
|
||||
if type(n) is int and (n < 0 or n >= v.piece_size()):
|
||||
raise IndexError('piece id is out of range.')
|
||||
return func(v, n)
|
||||
|
||||
def _batched_func(self, arg):
|
||||
if type(arg) is list:
|
||||
return [func(self, n) for n in arg]
|
||||
return [_func(self, n) for n in arg]
|
||||
else:
|
||||
return func(self, arg)
|
||||
return _func(self, arg)
|
||||
|
||||
setattr(classname, name, _batched_func)
|
||||
|
||||
@ -703,6 +729,8 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
|
||||
|
||||
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
|
||||
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
|
||||
SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck
|
||||
SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck
|
||||
|
||||
for m in [
|
||||
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
|
||||
|
@ -3285,6 +3285,22 @@ SWIGINTERNINLINE PyObject*
|
||||
SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_LoadFromFile(sentencepiece::SentencePieceProcessor *self,absl::string_view arg){
|
||||
return self->Load(arg);
|
||||
}
|
||||
SWIGINTERN std::string sentencepiece_SentencePieceProcessor_DecodeIdsWithCheck(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){
|
||||
for (int id : ids)
|
||||
if (id < 0 || id >= self->GetPieceSize())
|
||||
throw sentencepiece::util::Status(
|
||||
sentencepiece::util::StatusCode::kOutOfRange,
|
||||
"piece id is out of range.");
|
||||
return self->DecodeIds(ids);
|
||||
}
|
||||
SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){
|
||||
for (int id : ids)
|
||||
if (id < 0 || id >= self->GetPieceSize())
|
||||
throw sentencepiece::util::Status(
|
||||
sentencepiece::util::StatusCode::kOutOfRange,
|
||||
"piece id is out of range.");
|
||||
return self->DecodeIdsAsSerializedProto(ids);
|
||||
}
|
||||
SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string_view arg){
|
||||
const auto _status = sentencepiece::SentencePieceTrainer::Train(arg);
|
||||
if (!_status.ok()) throw _status;
|
||||
@ -4117,66 +4133,6 @@ fail:
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIds(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
|
||||
std::vector< int > *arg2 = 0 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
std::string result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIds", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIds" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
|
||||
{
|
||||
std::vector<int> *out = nullptr;
|
||||
if (PyList_Check(swig_obj[1])) {
|
||||
const size_t size = PyList_Size(swig_obj[1]);
|
||||
out = new std::vector<int>(size);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
PyObject *o = PyList_GetItem(swig_obj[1], i);
|
||||
if (PyInt_Check(o)) {
|
||||
(*out)[i] = static_cast<int>(PyInt_AsLong(o));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"list must contain integers");
|
||||
SWIG_fail;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"not a list");
|
||||
SWIG_fail;
|
||||
}
|
||||
arg2 = out;
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = ((sentencepiece::SentencePieceProcessor const *)arg1)->DecodeIds((std::vector< int > const &)*arg2);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
PyObject *input_type = resultobj;
|
||||
resultobj = MakePyOutputString(result, input_type);
|
||||
}
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_EncodeAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
|
||||
@ -4387,65 +4343,6 @@ fail:
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
|
||||
std::vector< int > *arg2 = 0 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
sentencepiece::util::bytes result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIdsAsSerializedProto", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsAsSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
|
||||
{
|
||||
std::vector<int> *out = nullptr;
|
||||
if (PyList_Check(swig_obj[1])) {
|
||||
const size_t size = PyList_Size(swig_obj[1]);
|
||||
out = new std::vector<int>(size);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
PyObject *o = PyList_GetItem(swig_obj[1], i);
|
||||
if (PyInt_Check(o)) {
|
||||
(*out)[i] = static_cast<int>(PyInt_AsLong(o));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"list must contain integers");
|
||||
SWIG_fail;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"not a list");
|
||||
SWIG_fail;
|
||||
}
|
||||
arg2 = out;
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = ((sentencepiece::SentencePieceProcessor const *)arg1)->DecodeIdsAsSerializedProto((std::vector< int > const &)*arg2);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
resultobj = MakePyOutputBytes(result);
|
||||
}
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_GetPieceSize(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
|
||||
@ -4950,6 +4847,125 @@ fail:
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsWithCheck(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
|
||||
std::vector< int > *arg2 = 0 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
std::string result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIdsWithCheck", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsWithCheck" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
|
||||
{
|
||||
std::vector<int> *out = nullptr;
|
||||
if (PyList_Check(swig_obj[1])) {
|
||||
const size_t size = PyList_Size(swig_obj[1]);
|
||||
out = new std::vector<int>(size);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
PyObject *o = PyList_GetItem(swig_obj[1], i);
|
||||
if (PyInt_Check(o)) {
|
||||
(*out)[i] = static_cast<int>(PyInt_AsLong(o));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"list must contain integers");
|
||||
SWIG_fail;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"not a list");
|
||||
SWIG_fail;
|
||||
}
|
||||
arg2 = out;
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = sentencepiece_SentencePieceProcessor_DecodeIdsWithCheck((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< int > const &)*arg2);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
PyObject *input_type = resultobj;
|
||||
resultobj = MakePyOutputString(result, input_type);
|
||||
}
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *resultobj = 0;
|
||||
sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ;
|
||||
std::vector< int > *arg2 = 0 ;
|
||||
void *argp1 = 0 ;
|
||||
int res1 = 0 ;
|
||||
PyObject *swig_obj[2] ;
|
||||
sentencepiece::util::bytes result;
|
||||
|
||||
if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", 2, 2, swig_obj)) SWIG_fail;
|
||||
res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 );
|
||||
if (!SWIG_IsOK(res1)) {
|
||||
SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'");
|
||||
}
|
||||
arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1);
|
||||
{
|
||||
std::vector<int> *out = nullptr;
|
||||
if (PyList_Check(swig_obj[1])) {
|
||||
const size_t size = PyList_Size(swig_obj[1]);
|
||||
out = new std::vector<int>(size);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
PyObject *o = PyList_GetItem(swig_obj[1], i);
|
||||
if (PyInt_Check(o)) {
|
||||
(*out)[i] = static_cast<int>(PyInt_AsLong(o));
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"list must contain integers");
|
||||
SWIG_fail;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PyErr_SetString(PyExc_TypeError,"not a list");
|
||||
SWIG_fail;
|
||||
}
|
||||
arg2 = out;
|
||||
}
|
||||
{
|
||||
try {
|
||||
result = sentencepiece_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< int > const &)*arg2);
|
||||
ReleaseResultObject(resultobj);
|
||||
}
|
||||
catch (const sentencepiece::util::Status &status) {
|
||||
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
|
||||
}
|
||||
}
|
||||
{
|
||||
resultobj = MakePyOutputBytes(result);
|
||||
}
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return resultobj;
|
||||
fail:
|
||||
{
|
||||
delete arg2;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
SWIGINTERN PyObject *SentencePieceProcessor_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) {
|
||||
PyObject *obj;
|
||||
if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL;
|
||||
@ -5269,12 +5285,10 @@ static PyMethodDef SwigMethods[] = {
|
||||
{ "SentencePieceProcessor_SampleEncodeAsPieces", _wrap_SentencePieceProcessor_SampleEncodeAsPieces, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_SampleEncodeAsIds", _wrap_SentencePieceProcessor_SampleEncodeAsIds, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_DecodePieces", _wrap_SentencePieceProcessor_DecodePieces, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_DecodeIds", _wrap_SentencePieceProcessor_DecodeIds, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_EncodeAsSerializedProto", _wrap_SentencePieceProcessor_EncodeAsSerializedProto, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_SampleEncodeAsSerializedProto", _wrap_SentencePieceProcessor_SampleEncodeAsSerializedProto, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_NBestEncodeAsSerializedProto", _wrap_SentencePieceProcessor_NBestEncodeAsSerializedProto, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_DecodePiecesAsSerializedProto", _wrap_SentencePieceProcessor_DecodePiecesAsSerializedProto, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_DecodeIdsAsSerializedProto", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProto, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_GetPieceSize", _wrap_SentencePieceProcessor_GetPieceSize, METH_O, NULL},
|
||||
{ "SentencePieceProcessor_PieceToId", _wrap_SentencePieceProcessor_PieceToId, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_IdToPiece", _wrap_SentencePieceProcessor_IdToPiece, METH_VARARGS, NULL},
|
||||
@ -5289,6 +5303,8 @@ static PyMethodDef SwigMethods[] = {
|
||||
{ "SentencePieceProcessor_pad_id", _wrap_SentencePieceProcessor_pad_id, METH_O, NULL},
|
||||
{ "SentencePieceProcessor_serialized_model_proto", _wrap_SentencePieceProcessor_serialized_model_proto, METH_O, NULL},
|
||||
{ "SentencePieceProcessor_LoadFromFile", _wrap_SentencePieceProcessor_LoadFromFile, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_DecodeIdsWithCheck", _wrap_SentencePieceProcessor_DecodeIdsWithCheck, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck, METH_VARARGS, NULL},
|
||||
{ "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},
|
||||
{ "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL},
|
||||
{ "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL},
|
||||
|
@ -372,6 +372,22 @@ class TestSentencepieceProcessor(unittest.TestCase):
|
||||
++ids2[' '.join(sp.encode('hello world', enable_sampling=False))]
|
||||
self.assertEqual(len(ids2), 1)
|
||||
|
||||
def test_valid_range(self):
|
||||
size = self.sp_.piece_size()
|
||||
funcs = [
|
||||
'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', 'IsByte',
|
||||
'DecodeIds', 'DecodeIdsAsSerializedProto'
|
||||
]
|
||||
for m in funcs:
|
||||
getattr(self.sp_, m)([10, 20, 30])
|
||||
|
||||
for m in funcs:
|
||||
try:
|
||||
getattr(self.sp_, m)([size])
|
||||
self.assertTrue(False)
|
||||
except:
|
||||
self.assertTrue(True)
|
||||
|
||||
|
||||
def suite():
|
||||
suite = unittest.TestSuite()
|
||||
|
@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "bpe_model.h"
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
@ -19,7 +21,6 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "bpe_model.h"
|
||||
#include "freelist.h"
|
||||
#include "third_party/absl/container/flat_hash_map.h"
|
||||
#include "util.h"
|
||||
@ -132,6 +133,7 @@ std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
|
||||
std::mt19937 *rand_gen = nullptr;
|
||||
auto skip_merge = [&]() {
|
||||
if (alpha <= 0.0) return false;
|
||||
if (alpha >= 1.0) return true;
|
||||
if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator();
|
||||
std::uniform_real_distribution<> gen(0.0, 1.0);
|
||||
return gen(*rand_gen) < alpha;
|
||||
|
@ -493,6 +493,12 @@ class SentencePieceProcessor {
|
||||
std::vector<ExtraOption> decode_extra_options_;
|
||||
};
|
||||
|
||||
// Set seed value of random generator.
|
||||
// Do not set static_cast<unique_int>(-1),
|
||||
// as this seed is reserved for initializing from
|
||||
// std::random_device.
|
||||
void SetRandomGeneratorSeed(unsigned int seed);
|
||||
|
||||
#ifndef SWIG
|
||||
// IO related functions to absorb model formats.
|
||||
namespace io {
|
||||
|
@ -28,15 +28,16 @@
|
||||
#include "trainer_interface.h"
|
||||
|
||||
ABSL_FLAG(std::string, model, "", "model file name");
|
||||
ABSL_FLAG(
|
||||
std::string, output_format, "piece",
|
||||
"choose from piece, id, proto, nbest_piece, nbest_id, or nbest_proto");
|
||||
ABSL_FLAG(std::string, output_format, "piece",
|
||||
"choose from piece, id, proto, nbest_piece, nbest_id, nbest_proto,
|
||||
"sample_piece, sample_id or sample_proto.");
|
||||
ABSL_FLAG(std::string, input, "", "input filename");
|
||||
ABSL_FLAG(std::string, output, "", "output filename");
|
||||
ABSL_FLAG(std::string, extra_options, "",
|
||||
"':' separated encoder extra options, e.g., \"reverse:bos:eos\"");
|
||||
ABSL_FLAG(int32, nbest_size, 10, "NBest size");
|
||||
ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode.");
|
||||
ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator.");
|
||||
|
||||
// Piece restriction with vocabulary file.
|
||||
// https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt
|
||||
|
19
src/util.cc
19
src/util.cc
@ -12,11 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "util.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace sentencepiece {
|
||||
namespace {
|
||||
constexpr unsigned int kDefaultSeed = static_cast<unsigned int>(-1);
|
||||
static unsigned int g_seed = kDefaultSeed;
|
||||
} // namespace
|
||||
|
||||
void SetRandomGeneratorSeed(unsigned int seed) {
|
||||
if (seed != kDefaultSeed) g_seed = seed;
|
||||
}
|
||||
|
||||
namespace string_util {
|
||||
|
||||
// mblen sotres the number of bytes consumed after decoding.
|
||||
@ -144,7 +153,8 @@ class RandomGeneratorStorage {
|
||||
std::mt19937 *Get() {
|
||||
auto *result = static_cast<std::mt19937 *>(pthread_getspecific(key_));
|
||||
if (result == nullptr) {
|
||||
result = new std::mt19937(std::random_device{}());
|
||||
result = new std::mt19937(g_seed == kDefaultSeed ? std::random_device{}()
|
||||
: g_seed);
|
||||
pthread_setspecific(key_, result);
|
||||
}
|
||||
return result;
|
||||
@ -162,7 +172,8 @@ std::mt19937 *GetRandomGenerator() {
|
||||
}
|
||||
#else
|
||||
std::mt19937 *GetRandomGenerator() {
|
||||
thread_local static std::mt19937 mt(std::random_device{}());
|
||||
thread_local static std::mt19937 mt(
|
||||
g_seed == kDefaultSeed ? std::random_device{}() : g_seed);
|
||||
return &mt;
|
||||
}
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user