diff --git a/.gitignore b/.gitignore index 20073e7..aac6692 100644 --- a/.gitignore +++ b/.gitignore @@ -71,3 +71,4 @@ cmake_install.cmake libsentencepiece.so* libsentencepiece_train.so* python/bundled +_sentencepiece.*.so diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index e704a2a..001ffc7 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -116,9 +116,6 @@ class SentencePieceProcessor(object): def DecodePieces(self, pieces): return _sentencepiece.SentencePieceProcessor_DecodePieces(self, pieces) - def DecodeIds(self, ids): - return _sentencepiece.SentencePieceProcessor_DecodeIds(self, ids) - def EncodeAsSerializedProto(self, input): return _sentencepiece.SentencePieceProcessor_EncodeAsSerializedProto(self, input) @@ -131,9 +128,6 @@ class SentencePieceProcessor(object): def DecodePiecesAsSerializedProto(self, pieces): return _sentencepiece.SentencePieceProcessor_DecodePiecesAsSerializedProto(self, pieces) - def DecodeIdsAsSerializedProto(self, ids): - return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProto(self, ids) - def GetPieceSize(self): return _sentencepiece.SentencePieceProcessor_GetPieceSize(self) @@ -176,6 +170,12 @@ class SentencePieceProcessor(object): def LoadFromFile(self, arg): return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg) + def DecodeIdsWithCheck(self, ids): + return _sentencepiece.SentencePieceProcessor_DecodeIdsWithCheck(self, ids) + + def DecodeIdsAsSerializedProtoWithCheck(self, ids): + return _sentencepiece.SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(self, ids) + def Init(self, model_file=None, model_proto=None, @@ -242,8 +242,8 @@ class SentencePieceProcessor(object): nbest_size < 0: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) using forward-filtering-and-backward-sampling algorithm. - alpha: Soothing parameter for unigram sampling, and dropout probability of - merge operations for BPE-dropout. + alpha: Soothing parameter for unigram sampling, and merge probability for + BPE-dropout (probablity 'p' in BPE-dropout paper). """ if out_type is None: @@ -262,12 +262,12 @@ class SentencePieceProcessor(object): alpha = self._alpha if enable_sampling == True and (nbest_size is None or nbest_size == 0 or - nbest_size == 1 or alpha is None or - alpha <= 0.0 or alpha > 1.0): + nbest_size == 1 or alpha is None): raise RuntimeError( 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", ' - 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and ' - 'samples from all candidates on the lattice instead of nbest segmentations. ' + 'and "alpha". "nbest_size" is enabled only on unigram mode ignored in BPE-dropout. ' + 'when "nbest_size = -1" , this method samples from all candidates on the lattice ' + 'instead of nbest segmentations.' ) def _encode(text): @@ -310,7 +310,7 @@ class SentencePieceProcessor(object): if not input: return self.DecodeIds([]) elif type(input) is int: - return self.DecodeIds([input]) + return self.DecodeIdsWithCheck([input]) elif type(input) is str: return self.DecodePieces([input]) @@ -318,7 +318,7 @@ class SentencePieceProcessor(object): if not input: return self.DecodeIds([]) if type(input[0]) is int: - return self.DecodeIds(input) + return self.DecodeIdsWithCheck(input) return self.DecodePieces(input) if type(input[0]) is list: @@ -486,12 +486,16 @@ def _add_snake_case(classname): def _batchnize(classname, name): """Enables batch request for the method classname.name.""" func = getattr(classname, name, None) + def _func(v, n): + if type(n) is int and (n < 0 or n >= v.piece_size()): + raise IndexError('piece id is out of range.') + return func(v, n) def _batched_func(self, arg): if type(arg) is list: - return [func(self, n) for n in arg] + return [_func(self, n) for n in arg] else: - return func(self, arg) + return _func(self, arg) setattr(classname, name, _batched_func) @@ -501,6 +505,8 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init) SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode +SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck +SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck for m in [ 'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', diff --git a/python/src/sentencepiece/_sentencepiece.cpython-38-x86_64-linux-gnu.so b/python/src/sentencepiece/_sentencepiece.cpython-38-x86_64-linux-gnu.so deleted file mode 100755 index 30c3c3c..0000000 Binary files a/python/src/sentencepiece/_sentencepiece.cpython-38-x86_64-linux-gnu.so and /dev/null differ diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index 04f3af0..6522d1f 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -176,6 +176,8 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { %ignore sentencepiece::SentencePieceProcessor::SampleEncode; %ignore sentencepiece::SentencePieceProcessor::NBestEncode; %ignore sentencepiece::SentencePieceProcessor::Decode; +%ignore sentencepiece::SentencePieceProcessor::DecodeIds; +%ignore sentencepiece::SentencePieceProcessor::DecodeIdsAsSerializedProto; %ignore sentencepiece::SentencePieceProcessor::model_proto; %ignore sentencepiece::SentencePieceProcessor::Load; %ignore sentencepiece::SentencePieceProcessor::LoadOrDie; @@ -196,6 +198,26 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { return $self->Load(arg); } + std::string DecodeIdsWithCheck( + const std::vector &ids) const { + for (int id : ids) + if (id < 0 || id >= $self->GetPieceSize()) + throw sentencepiece::util::Status( + sentencepiece::util::StatusCode::kOutOfRange, + "piece id is out of range."); + return $self->DecodeIds(ids); + } + + util::bytes DecodeIdsAsSerializedProtoWithCheck( + const std::vector &ids) const { + for (int id : ids) + if (id < 0 || id >= $self->GetPieceSize()) + throw sentencepiece::util::Status( + sentencepiece::util::StatusCode::kOutOfRange, + "piece id is out of range."); + return $self->DecodeIdsAsSerializedProto(ids); + } + %pythoncode { def Init(self, model_file=None, @@ -264,7 +286,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { from the all hypothesis (lattice) using forward-filtering-and-backward-sampling algorithm. alpha: Soothing parameter for unigram sampling, and merge probability for - BPE-dropout. + BPE-dropout (probablity 'p' in BPE-dropout paper). """ if out_type is None: @@ -283,12 +305,12 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { alpha = self._alpha if enable_sampling == True and (nbest_size is None or nbest_size == 0 or - nbest_size == 1 or alpha is None or - alpha <= 0.0 or alpha > 1.0): + nbest_size == 1 or alpha is None): raise RuntimeError( 'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", ' - 'and "0.0 < alpha < 1.0". "nbest_size = -1" is enabled only on unigram mode and ' - 'samples from all candidates on the lattice instead of nbest segmentations. ' + 'and "alpha". "nbest_size" is enabled only on unigram mode ignored in BPE-dropout. ' + 'when "nbest_size = -1" , this method samples from all candidates on the lattice ' + 'instead of nbest segmentations.' ) def _encode(text): @@ -331,7 +353,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { if not input: return self.DecodeIds([]) elif type(input) is int: - return self.DecodeIds([input]) + return self.DecodeIdsWithCheck([input]) elif type(input) is str: return self.DecodePieces([input]) @@ -339,7 +361,7 @@ class PySentenceIterator : public sentencepiece::SentenceIterator { if not input: return self.DecodeIds([]) if type(input[0]) is int: - return self.DecodeIds(input) + return self.DecodeIdsWithCheck(input) return self.DecodePieces(input) if type(input[0]) is list: @@ -688,12 +710,16 @@ def _add_snake_case(classname): def _batchnize(classname, name): """Enables batch request for the method classname.name.""" func = getattr(classname, name, None) + def _func(v, n): + if type(n) is int and (n < 0 or n >= v.piece_size()): + raise IndexError('piece id is out of range.') + return func(v, n) def _batched_func(self, arg): if type(arg) is list: - return [func(self, n) for n in arg] + return [_func(self, n) for n in arg] else: - return func(self, arg) + return _func(self, arg) setattr(classname, name, _batched_func) @@ -703,6 +729,8 @@ setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init) SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode +SentencePieceProcessor.DecodeIds = SentencePieceProcessor.DecodeIdsWithCheck +SentencePieceProcessor.DecodeIdsAsSerializedProto = SentencePieceProcessor.DecodeIdsAsSerializedProtoWithCheck for m in [ 'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', diff --git a/python/src/sentencepiece/sentencepiece_wrap.cxx b/python/src/sentencepiece/sentencepiece_wrap.cxx index bd7f6a1..7e2e85d 100644 --- a/python/src/sentencepiece/sentencepiece_wrap.cxx +++ b/python/src/sentencepiece/sentencepiece_wrap.cxx @@ -3285,6 +3285,22 @@ SWIGINTERNINLINE PyObject* SWIGINTERN sentencepiece::util::Status sentencepiece_SentencePieceProcessor_LoadFromFile(sentencepiece::SentencePieceProcessor *self,absl::string_view arg){ return self->Load(arg); } +SWIGINTERN std::string sentencepiece_SentencePieceProcessor_DecodeIdsWithCheck(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){ + for (int id : ids) + if (id < 0 || id >= self->GetPieceSize()) + throw sentencepiece::util::Status( + sentencepiece::util::StatusCode::kOutOfRange, + "piece id is out of range."); + return self->DecodeIds(ids); + } +SWIGINTERN sentencepiece::util::bytes sentencepiece_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(sentencepiece::SentencePieceProcessor const *self,std::vector< int > const &ids){ + for (int id : ids) + if (id < 0 || id >= self->GetPieceSize()) + throw sentencepiece::util::Status( + sentencepiece::util::StatusCode::kOutOfRange, + "piece id is out of range."); + return self->DecodeIdsAsSerializedProto(ids); + } SWIGINTERN void sentencepiece_SentencePieceTrainer__TrainFromString(absl::string_view arg){ const auto _status = sentencepiece::SentencePieceTrainer::Train(arg); if (!_status.ok()) throw _status; @@ -4117,66 +4133,6 @@ fail: } -SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIds(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { - PyObject *resultobj = 0; - sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; - std::vector< int > *arg2 = 0 ; - void *argp1 = 0 ; - int res1 = 0 ; - PyObject *swig_obj[2] ; - std::string result; - - if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIds", 2, 2, swig_obj)) SWIG_fail; - res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); - if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIds" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); - } - arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); - { - std::vector *out = nullptr; - if (PyList_Check(swig_obj[1])) { - const size_t size = PyList_Size(swig_obj[1]); - out = new std::vector(size); - for (size_t i = 0; i < size; ++i) { - PyObject *o = PyList_GetItem(swig_obj[1], i); - if (PyInt_Check(o)) { - (*out)[i] = static_cast(PyInt_AsLong(o)); - } else { - PyErr_SetString(PyExc_TypeError,"list must contain integers"); - SWIG_fail; - } - } - } else { - PyErr_SetString(PyExc_TypeError,"not a list"); - SWIG_fail; - } - arg2 = out; - } - { - try { - result = ((sentencepiece::SentencePieceProcessor const *)arg1)->DecodeIds((std::vector< int > const &)*arg2); - ReleaseResultObject(resultobj); - } - catch (const sentencepiece::util::Status &status) { - SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); - } - } - { - PyObject *input_type = resultobj; - resultobj = MakePyOutputString(result, input_type); - } - { - delete arg2; - } - return resultobj; -fail: - { - delete arg2; - } - return NULL; -} - - SWIGINTERN PyObject *_wrap_SentencePieceProcessor_EncodeAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -4387,65 +4343,6 @@ fail: } -SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsAsSerializedProto(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { - PyObject *resultobj = 0; - sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; - std::vector< int > *arg2 = 0 ; - void *argp1 = 0 ; - int res1 = 0 ; - PyObject *swig_obj[2] ; - sentencepiece::util::bytes result; - - if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIdsAsSerializedProto", 2, 2, swig_obj)) SWIG_fail; - res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); - if (!SWIG_IsOK(res1)) { - SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsAsSerializedProto" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); - } - arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); - { - std::vector *out = nullptr; - if (PyList_Check(swig_obj[1])) { - const size_t size = PyList_Size(swig_obj[1]); - out = new std::vector(size); - for (size_t i = 0; i < size; ++i) { - PyObject *o = PyList_GetItem(swig_obj[1], i); - if (PyInt_Check(o)) { - (*out)[i] = static_cast(PyInt_AsLong(o)); - } else { - PyErr_SetString(PyExc_TypeError,"list must contain integers"); - SWIG_fail; - } - } - } else { - PyErr_SetString(PyExc_TypeError,"not a list"); - SWIG_fail; - } - arg2 = out; - } - { - try { - result = ((sentencepiece::SentencePieceProcessor const *)arg1)->DecodeIdsAsSerializedProto((std::vector< int > const &)*arg2); - ReleaseResultObject(resultobj); - } - catch (const sentencepiece::util::Status &status) { - SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); - } - } - { - resultobj = MakePyOutputBytes(result); - } - { - delete arg2; - } - return resultobj; -fail: - { - delete arg2; - } - return NULL; -} - - SWIGINTERN PyObject *_wrap_SentencePieceProcessor_GetPieceSize(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *resultobj = 0; sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; @@ -4950,6 +4847,125 @@ fail: } +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsWithCheck(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + std::vector< int > *arg2 = 0 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject *swig_obj[2] ; + std::string result; + + if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIdsWithCheck", 2, 2, swig_obj)) SWIG_fail; + res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsWithCheck" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + std::vector *out = nullptr; + if (PyList_Check(swig_obj[1])) { + const size_t size = PyList_Size(swig_obj[1]); + out = new std::vector(size); + for (size_t i = 0; i < size; ++i) { + PyObject *o = PyList_GetItem(swig_obj[1], i); + if (PyInt_Check(o)) { + (*out)[i] = static_cast(PyInt_AsLong(o)); + } else { + PyErr_SetString(PyExc_TypeError,"list must contain integers"); + SWIG_fail; + } + } + } else { + PyErr_SetString(PyExc_TypeError,"not a list"); + SWIG_fail; + } + arg2 = out; + } + { + try { + result = sentencepiece_SentencePieceProcessor_DecodeIdsWithCheck((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< int > const &)*arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + PyObject *input_type = resultobj; + resultobj = MakePyOutputString(result, input_type); + } + { + delete arg2; + } + return resultobj; +fail: + { + delete arg2; + } + return NULL; +} + + +SWIGINTERN PyObject *_wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { + PyObject *resultobj = 0; + sentencepiece::SentencePieceProcessor *arg1 = (sentencepiece::SentencePieceProcessor *) 0 ; + std::vector< int > *arg2 = 0 ; + void *argp1 = 0 ; + int res1 = 0 ; + PyObject *swig_obj[2] ; + sentencepiece::util::bytes result; + + if (!SWIG_Python_UnpackTuple(args, "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", 2, 2, swig_obj)) SWIG_fail; + res1 = SWIG_ConvertPtr(swig_obj[0], &argp1,SWIGTYPE_p_sentencepiece__SentencePieceProcessor, 0 | 0 ); + if (!SWIG_IsOK(res1)) { + SWIG_exception_fail(SWIG_ArgError(res1), "in method '" "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck" "', argument " "1"" of type '" "sentencepiece::SentencePieceProcessor const *""'"); + } + arg1 = reinterpret_cast< sentencepiece::SentencePieceProcessor * >(argp1); + { + std::vector *out = nullptr; + if (PyList_Check(swig_obj[1])) { + const size_t size = PyList_Size(swig_obj[1]); + out = new std::vector(size); + for (size_t i = 0; i < size; ++i) { + PyObject *o = PyList_GetItem(swig_obj[1], i); + if (PyInt_Check(o)) { + (*out)[i] = static_cast(PyInt_AsLong(o)); + } else { + PyErr_SetString(PyExc_TypeError,"list must contain integers"); + SWIG_fail; + } + } + } else { + PyErr_SetString(PyExc_TypeError,"not a list"); + SWIG_fail; + } + arg2 = out; + } + { + try { + result = sentencepiece_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck((sentencepiece::SentencePieceProcessor const *)arg1,(std::vector< int > const &)*arg2); + ReleaseResultObject(resultobj); + } + catch (const sentencepiece::util::Status &status) { + SWIG_exception(ToSwigError(status.code()), status.ToString().c_str()); + } + } + { + resultobj = MakePyOutputBytes(result); + } + { + delete arg2; + } + return resultobj; +fail: + { + delete arg2; + } + return NULL; +} + + SWIGINTERN PyObject *SentencePieceProcessor_swigregister(PyObject *SWIGUNUSEDPARM(self), PyObject *args) { PyObject *obj; if (!SWIG_Python_UnpackTuple(args, "swigregister", 1, 1, &obj)) return NULL; @@ -5269,12 +5285,10 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor_SampleEncodeAsPieces", _wrap_SentencePieceProcessor_SampleEncodeAsPieces, METH_VARARGS, NULL}, { "SentencePieceProcessor_SampleEncodeAsIds", _wrap_SentencePieceProcessor_SampleEncodeAsIds, METH_VARARGS, NULL}, { "SentencePieceProcessor_DecodePieces", _wrap_SentencePieceProcessor_DecodePieces, METH_VARARGS, NULL}, - { "SentencePieceProcessor_DecodeIds", _wrap_SentencePieceProcessor_DecodeIds, METH_VARARGS, NULL}, { "SentencePieceProcessor_EncodeAsSerializedProto", _wrap_SentencePieceProcessor_EncodeAsSerializedProto, METH_VARARGS, NULL}, { "SentencePieceProcessor_SampleEncodeAsSerializedProto", _wrap_SentencePieceProcessor_SampleEncodeAsSerializedProto, METH_VARARGS, NULL}, { "SentencePieceProcessor_NBestEncodeAsSerializedProto", _wrap_SentencePieceProcessor_NBestEncodeAsSerializedProto, METH_VARARGS, NULL}, { "SentencePieceProcessor_DecodePiecesAsSerializedProto", _wrap_SentencePieceProcessor_DecodePiecesAsSerializedProto, METH_VARARGS, NULL}, - { "SentencePieceProcessor_DecodeIdsAsSerializedProto", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProto, METH_VARARGS, NULL}, { "SentencePieceProcessor_GetPieceSize", _wrap_SentencePieceProcessor_GetPieceSize, METH_O, NULL}, { "SentencePieceProcessor_PieceToId", _wrap_SentencePieceProcessor_PieceToId, METH_VARARGS, NULL}, { "SentencePieceProcessor_IdToPiece", _wrap_SentencePieceProcessor_IdToPiece, METH_VARARGS, NULL}, @@ -5289,6 +5303,8 @@ static PyMethodDef SwigMethods[] = { { "SentencePieceProcessor_pad_id", _wrap_SentencePieceProcessor_pad_id, METH_O, NULL}, { "SentencePieceProcessor_serialized_model_proto", _wrap_SentencePieceProcessor_serialized_model_proto, METH_O, NULL}, { "SentencePieceProcessor_LoadFromFile", _wrap_SentencePieceProcessor_LoadFromFile, METH_VARARGS, NULL}, + { "SentencePieceProcessor_DecodeIdsWithCheck", _wrap_SentencePieceProcessor_DecodeIdsWithCheck, METH_VARARGS, NULL}, + { "SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck", _wrap_SentencePieceProcessor_DecodeIdsAsSerializedProtoWithCheck, METH_VARARGS, NULL}, { "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL}, { "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL}, { "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL}, diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 9264914..7bf1c13 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -372,6 +372,22 @@ class TestSentencepieceProcessor(unittest.TestCase): ++ids2[' '.join(sp.encode('hello world', enable_sampling=False))] self.assertEqual(len(ids2), 1) + def test_valid_range(self): + size = self.sp_.piece_size() + funcs = [ + 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', 'IsByte', + 'DecodeIds', 'DecodeIdsAsSerializedProto' + ] + for m in funcs: + getattr(self.sp_, m)([10, 20, 30]) + + for m in funcs: + try: + getattr(self.sp_, m)([size]) + self.assertTrue(False) + except: + self.assertTrue(True) + def suite(): suite = unittest.TestSuite() diff --git a/src/bpe_model.cc b/src/bpe_model.cc index f1a97f4..5d77baa 100644 --- a/src/bpe_model.cc +++ b/src/bpe_model.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License.! +#include "bpe_model.h" + #include #include #include @@ -19,7 +21,6 @@ #include #include -#include "bpe_model.h" #include "freelist.h" #include "third_party/absl/container/flat_hash_map.h" #include "util.h" @@ -132,6 +133,7 @@ std::vector> Model::SampleEncode( std::mt19937 *rand_gen = nullptr; auto skip_merge = [&]() { if (alpha <= 0.0) return false; + if (alpha >= 1.0) return true; if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator(); std::uniform_real_distribution<> gen(0.0, 1.0); return gen(*rand_gen) < alpha; diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 1c7fa6d..7227920 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -493,6 +493,12 @@ class SentencePieceProcessor { std::vector decode_extra_options_; }; +// Set seed value of random generator. +// Do not set static_cast(-1), +// as this seed is reserved for initializing from +// std::random_device. +void SetRandomGeneratorSeed(unsigned int seed); + #ifndef SWIG // IO related functions to absorb model formats. namespace io { diff --git a/src/spm_encode_main.cc b/src/spm_encode_main.cc index a04ca18..9b018f9 100644 --- a/src/spm_encode_main.cc +++ b/src/spm_encode_main.cc @@ -28,15 +28,16 @@ #include "trainer_interface.h" ABSL_FLAG(std::string, model, "", "model file name"); -ABSL_FLAG( - std::string, output_format, "piece", - "choose from piece, id, proto, nbest_piece, nbest_id, or nbest_proto"); +ABSL_FLAG(std::string, output_format, "piece", + "choose from piece, id, proto, nbest_piece, nbest_id, nbest_proto, + "sample_piece, sample_id or sample_proto."); ABSL_FLAG(std::string, input, "", "input filename"); ABSL_FLAG(std::string, output, "", "output filename"); ABSL_FLAG(std::string, extra_options, "", "':' separated encoder extra options, e.g., \"reverse:bos:eos\""); ABSL_FLAG(int32, nbest_size, 10, "NBest size"); ABSL_FLAG(double, alpha, 0.5, "Smoothing parameter for sampling mode."); +ABSL_FLAG(int32, random_seed, -1, "Seed value for random generator."); // Piece restriction with vocabulary file. // https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt diff --git a/src/util.cc b/src/util.cc index d3946e1..e9ef6e6 100644 --- a/src/util.cc +++ b/src/util.cc @@ -12,11 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License.! -#include - #include "util.h" +#include + namespace sentencepiece { +namespace { +constexpr unsigned int kDefaultSeed = static_cast(-1); +static unsigned int g_seed = kDefaultSeed; +} // namespace + +void SetRandomGeneratorSeed(unsigned int seed) { + if (seed != kDefaultSeed) g_seed = seed; +} + namespace string_util { // mblen sotres the number of bytes consumed after decoding. @@ -144,7 +153,8 @@ class RandomGeneratorStorage { std::mt19937 *Get() { auto *result = static_cast(pthread_getspecific(key_)); if (result == nullptr) { - result = new std::mt19937(std::random_device{}()); + result = new std::mt19937(g_seed == kDefaultSeed ? std::random_device{}() + : g_seed); pthread_setspecific(key_, result); } return result; @@ -162,7 +172,8 @@ std::mt19937 *GetRandomGenerator() { } #else std::mt19937 *GetRandomGenerator() { - thread_local static std::mt19937 mt(std::random_device{}()); + thread_local static std::mt19937 mt( + g_seed == kDefaultSeed ? std::random_device{}() : g_seed); return &mt; } #endif