Add a set_min_log_level function to the Python wrapper to change the log level from Python.

Taku Kudo 2023-12-23 09:28:40 +00:00
parent bd3925a12e
commit 96aabaef96
8 changed files with 177 additions and 81 deletions
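For orientation, here is a minimal usage sketch of the binding this commit adds. The names are taken directly from the diff below; the severity values are an assumption (the common mapping INFO=0, WARNING=1, ERROR=2 is not spelled out in this commit):

import sentencepiece as spm

# Raise the global minimum log level of the underlying C++ library;
# messages whose level falls below it are suppressed (assumed mapping:
# INFO=0, WARNING=1, ERROR=2).
spm.set_min_log_level(2)

# The SWIG-generated CamelCase name is aliased to the snake_case one,
# so both spellings call the same function.
spm.SetMinLogLevel(2)

# The header comment added in this commit documents 0 as the default level.
spm.set_min_log_level(0)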

View File

@@ -14,14 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext as _build_ext
from setuptools.command.build_py import build_py as _build_py
import codecs
import os
import string
import subprocess
import sys
import os
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext
from setuptools.command.build_py import build_py as _build_py
sys.path.append(os.path.join('.', 'test'))
@@ -94,6 +94,8 @@ class build_ext(_build_ext):
else:
cflags.append('-Wl,-strip-all')
libs.append('-Wl,-strip-all')
if sys.platform == 'linux':
libs.append('-Wl,-Bsymbolic')
print('## cflags={}'.format(' '.join(cflags)))
print('## libs={}'.format(' '.join(libs)))
ext.extra_compile_args = cflags
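For context on the build change above: -Wl,-Bsymbolic asks the linker to bind the extension's references to its own global symbols at link time; presumably this keeps the bundled library's symbols from being interposed by another copy loaded in the same process. The motivation is inferred, not stated in the commit.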

View File

@@ -904,6 +904,9 @@ _sentencepiece.SentencePieceProcessor_swigregister(SentencePieceProcessor)
def SetRandomGeneratorSeed(seed):
return _sentencepiece.SetRandomGeneratorSeed(seed)
def SetMinLogLevel(v):
return _sentencepiece.SetMinLogLevel(v)
class SentencePieceTrainer(object):
thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")
@@ -1039,6 +1042,7 @@ for m in [
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
set_random_generator_seed = SetRandomGeneratorSeed
set_min_log_level = SetMinLogLevel
from ._version import __version__
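The alias above follows the existing set_random_generator_seed pattern, so both spellings resolve to the same function object; a quick check (a sketch, assuming the package re-exports the generated module's names):

import sentencepiece as spm

# Module-level assignment means identity, not just equality.
assert spm.set_min_log_level is spm.SetMinLogLevel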

View File

@@ -1771,6 +1771,7 @@ for m in [
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
set_random_generator_seed = SetRandomGeneratorSeed
set_min_log_level = SetMinLogLevel
from ._version import __version__

View File

@@ -8429,6 +8429,36 @@ fail:
}
SWIGINTERN PyObject *_wrap_SetMinLogLevel(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
int arg1 ;
int val1 ;
int ecode1 = 0 ;
PyObject *swig_obj[1] ;
if (!args) SWIG_fail;
swig_obj[0] = args;
ecode1 = SWIG_AsVal_int(swig_obj[0], &val1);
if (!SWIG_IsOK(ecode1)) {
SWIG_exception_fail(SWIG_ArgError(ecode1), "in method '" "SetMinLogLevel" "', argument " "1"" of type '" "int""'");
}
arg1 = static_cast< int >(val1);
{
try {
sentencepiece::SetMinLogLevel(arg1);
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
resultobj = SWIG_Py_Void();
return resultobj;
fail:
return NULL;
}
SWIGINTERN PyObject *_wrap_SentencePieceTrainer__TrainFromString(PyObject *self, PyObject *args) {
PyObject *resultobj = 0;
absl::string_view arg1 ;
@@ -8800,6 +8830,7 @@ static PyMethodDef SwigMethods[] = {
{ "SentencePieceProcessor_swigregister", SentencePieceProcessor_swigregister, METH_O, NULL},
{ "SentencePieceProcessor_swiginit", SentencePieceProcessor_swiginit, METH_VARARGS, NULL},
{ "SetRandomGeneratorSeed", _wrap_SetRandomGeneratorSeed, METH_O, NULL},
{ "SetMinLogLevel", _wrap_SetMinLogLevel, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromString", _wrap_SentencePieceTrainer__TrainFromString, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromMap", _wrap_SentencePieceTrainer__TrainFromMap, METH_O, NULL},
{ "SentencePieceTrainer__TrainFromMap2", _wrap_SentencePieceTrainer__TrainFromMap2, METH_VARARGS, NULL},

View File

@@ -15,14 +15,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import io
import sentencepiece as spm
import unittest
import sys
import os
import pickle
from collections import defaultdict
import sys
import unittest
import sentencepiece as spm
print('VERSION={}'.format(spm.__version__))
@@ -39,7 +38,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.jasp_ = spm.SentencePieceProcessor()
self.assertTrue(self.sp_.Load(os.path.join('test', 'test_model.model')))
self.assertTrue(
self.jasp_.Load(os.path.join('test', 'test_ja_model.model')))
self.jasp_.Load(os.path.join('test', 'test_ja_model.model'))
)
with open(os.path.join('test', 'test_model.model'), 'rb') as f:
self.assertTrue(self.sp_.LoadFromSerializedProto(f.read()))
with open(os.path.join('test', 'test_ja_model.model'), 'rb') as f:
@@ -83,14 +83,18 @@ class TestSentencepieceProcessor(unittest.TestCase):
for n in range(100):
self.assertEqual(
text,
self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, 64, 0.5)))
self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, 64, 0.5)),
)
self.assertEqual(
text,
self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, -1, 0.5)))
self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, -1, 0.5)),
)
self.assertEqual(
text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5)))
text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5))
)
self.assertEqual(
text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5)))
text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5))
)
ids2 = self.sp_.encode_as_ids(text)
pieces3 = self.sp_.encode_as_pieces(text)
@@ -104,21 +108,28 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual(
text,
self.sp_.decode_pieces(
self.sp_.sample_encode_as_pieces(text, 64, 0.5)))
self.sp_.sample_encode_as_pieces(text, 64, 0.5)
),
)
self.assertEqual(
text,
self.sp_.decode_pieces(
self.sp_.sample_encode_as_pieces(text, -1, 0.5)))
self.sp_.sample_encode_as_pieces(text, -1, 0.5)
),
)
self.assertEqual(
text,
self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, 64, 0.5)))
self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, 64, 0.5)),
)
self.assertEqual(
text,
self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, -1, 0.5)))
self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, -1, 0.5)),
)
self.assertEqual(
self.sp_.calculate_entropy(text, 0.1),
self.sp_.CalculateEntropy(text, 0.1))
self.sp_.CalculateEntropy(text, 0.1),
)
def test_ja_load(self):
self.assertEqual(8000, self.jasp_.GetPieceSize())
@@ -155,11 +166,15 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual(
text,
self.jasp_.DecodePieces(
self.jasp_.SampleEncodeAsPieces(text, 64, 0.5)))
self.jasp_.SampleEncodeAsPieces(text, 64, 0.5)
),
)
self.assertEqual(
text,
self.jasp_.DecodePieces(
self.jasp_.SampleEncodeAsPieces(text, -1, 0.5)))
self.jasp_.SampleEncodeAsPieces(text, -1, 0.5)
),
)
ids2 = self.jasp_.encode_as_ids(text)
pieces3 = self.jasp_.encode_as_pieces(text)
@@ -173,20 +188,27 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual(
text,
self.jasp_.decode_pieces(
self.jasp_.sample_encode_as_pieces(text, 64, 0.5)))
self.jasp_.sample_encode_as_pieces(text, 64, 0.5)
),
)
self.assertEqual(
text,
self.jasp_.decode_pieces(
self.jasp_.sample_encode_as_pieces(text, -1, 0.5)))
self.jasp_.sample_encode_as_pieces(text, -1, 0.5)
),
)
self.assertEqual(
self.jasp_.calculate_entropy(text, 0.1),
self.jasp_.CalculateEntropy(text, 0.1))
self.jasp_.CalculateEntropy(text, 0.1),
)
def test_train(self):
spm.SentencePieceTrainer.Train('--input=' +
os.path.join(data_dir, 'botchan.txt') +
' --model_prefix=m --vocab_size=1000')
spm.SentencePieceTrainer.Train(
'--input='
+ os.path.join(data_dir, 'botchan.txt')
+ ' --model_prefix=m --vocab_size=1000'
)
sp = spm.SentencePieceProcessor()
sp.Load('m.model')
with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
@@ -195,9 +217,11 @@ class TestSentencepieceProcessor(unittest.TestCase):
sp.DecodeIds(sp.EncodeAsIds(line))
def test_train_iterator(self):
spm.SentencePieceTrainer.Train('--input=' +
os.path.join(data_dir, 'botchan.txt') +
' --model_prefix=m --vocab_size=1000')
spm.SentencePieceTrainer.Train(
'--input='
+ os.path.join(data_dir, 'botchan.txt')
+ ' --model_prefix=m --vocab_size=1000'
)
# Load as 'rb' for Python3.5/2.7.
os1 = io.BytesIO()
os2 = io.BytesIO()
@@ -207,32 +231,38 @@ class TestSentencepieceProcessor(unittest.TestCase):
input=os.path.join(data_dir, 'botchan.txt'),
model_prefix='m',
vocab_size=1000,
logstream=open(os.devnull, 'w'))
logstream=open(os.devnull, 'w'),
)
with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is1:
spm.SentencePieceTrainer.train(
sentence_iterator=is1,
model_prefix='m',
vocab_size=1000,
logstream=open(os.devnull, 'w'))
logstream=open(os.devnull, 'w'),
)
spm.SentencePieceTrainer.train(
input=os.path.join(data_dir, 'botchan.txt'),
model_writer=os1,
vocab_size=1000,
logstream=open(os.devnull, 'w'))
logstream=open(os.devnull, 'w'),
)
with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is2:
spm.SentencePieceTrainer.train(
sentence_iterator=is2,
model_writer=os2,
vocab_size=1000,
logstream=open(os.devnull, 'w'))
logstream=open(os.devnull, 'w'),
)
sp1 = spm.SentencePieceProcessor(model_proto=os1.getvalue())
sp2 = spm.SentencePieceProcessor(model_proto=os2.getvalue())
self.assertEqual([sp1.id_to_piece(i) for i in range(sp1.get_piece_size())],
[sp2.id_to_piece(i) for i in range(sp2.get_piece_size())])
self.assertEqual(
[sp1.id_to_piece(i) for i in range(sp1.get_piece_size())],
[sp2.id_to_piece(i) for i in range(sp2.get_piece_size())],
)
def test_train_kwargs(self):
# suppress logging (redirect to /dev/null)
@@ -241,7 +271,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
model_prefix='m',
vocab_size=1002,
user_defined_symbols=['foo', 'bar', ',', ' ', '\t', '\b', '\n', '\r'],
logstream=open(os.devnull, 'w'))
logstream=open(os.devnull, 'w'),
)
sp = spm.SentencePieceProcessor()
sp.Load('m.model')
with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
@@ -268,7 +299,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
y1 = self.sp_.encode(text, out_type='serialized_proto')
y2 = self.sp_.encode(
text, enable_sampling=True, out_type='serialized_proto')
text, enable_sampling=True, out_type='serialized_proto'
)
y3 = self.sp_.nbest_encode(text, out_type='serialized_proto', nbest_size=10)
y4 = self.sp_.decode(['foo', 'bar'], out_type='serialized_proto')
y5 = self.sp_.decode([20, 30], out_type='serialized_proto')
@@ -372,7 +404,7 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual([x.piece for x in s1.pieces], v2)
self.assertEqual(text, s1.text)
surfaces1 = [s1.text[x.begin:x.end] for x in s1.pieces]
surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
surfaces2 = [x.surface for x in s1.pieces]
self.assertEqual(surfaces1, surfaces2)
@@ -393,15 +425,18 @@ class TestSentencepieceProcessor(unittest.TestCase):
for i in range(len(s3.nbests)):
self.assertEqual(text, s3.nbests[i].text)
self.assertEqual(
self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text)
self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text
)
# slice
self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces)))
self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests)))
# Japanese offset
s1 = self.jasp_.EncodeAsImmutableProto('吾輩は猫である。Hello world. ABC 123')
surfaces1 = [s1.text[x.begin:x.end] for x in s1.pieces]
s1 = self.jasp_.EncodeAsImmutableProto(
'吾輩は猫である。Hello world. ABC 123'
)
surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
surfaces2 = [x.surface for x in s1.pieces]
self.assertEqual(surfaces1, surfaces2)
@@ -415,7 +450,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
def test_new_api(self):
sp = spm.SentencePieceProcessor(
model_file=os.path.join('test', 'test_model.model'))
model_file=os.path.join('test', 'test_model.model')
)
text = 'hello world'
text2 = 'Tokyo'
ids = self.sp_.EncodeAsIds(text)
@@ -512,7 +548,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
model_file=os.path.join('test', 'test_model.model'),
add_bos=True,
add_eos=True,
out_type=str)
out_type=str,
)
text = 'hello world'
pieces = ['<s>'] + self.sp_.EncodeAsPieces(text) + ['</s>']
self.assertEqual(pieces, sp.encode(text))
@@ -540,13 +577,17 @@ class TestSentencepieceProcessor(unittest.TestCase):
++ids2[out]
self.assertEqual(len(ids2), 1)
out = sp.encode(['hello world', 'this is a test'],
out_type=out_type,
enable_sampling=True)
out = sp.encode(
['hello world', 'this is a test'],
out_type=out_type,
enable_sampling=True,
)
self.assertEqual(len(out), 2)
out = sp.encode(['hello world', 'this is a test'],
out_type=out_type,
enable_sampling=False)
out = sp.encode(
['hello world', 'this is a test'],
out_type=out_type,
enable_sampling=False,
)
self.assertEqual(len(out), 2)
def test_nbest(self):
@@ -556,8 +597,9 @@ class TestSentencepieceProcessor(unittest.TestCase):
for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
results = sp.nbest_encode(text, nbest_size=10, out_type=out_type)
self.assertEqual(results,
sp.NBestEncode(text, nbest_size=10, out_type=out_type))
self.assertEqual(
results, sp.NBestEncode(text, nbest_size=10, out_type=out_type)
)
if out_type in [str, int]:
for n in results:
@@ -570,7 +612,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
results = sp.nbest_encode([text, text2], nbest_size=10, out_type=out_type)
self.assertEqual(
results,
sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type))
sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type),
)
self.assertEqual(len(results), 2)
if out_type in [str, int]:
@@ -591,16 +634,20 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual(
sp.nbest_encode(text, nbest_size=10, out_type=str),
sp.nbest_encode_as_pieces(text, nbest_size=10))
sp.nbest_encode_as_pieces(text, nbest_size=10),
)
self.assertEqual(
sp.nbest_encode(text, nbest_size=10, out_type=int),
sp.nbest_encode_as_ids(text, nbest_size=10))
sp.nbest_encode_as_ids(text, nbest_size=10),
)
self.assertEqual(
sp.nbest_encode(text, nbest_size=10, out_type='serialized_proto'),
sp.nbest_encode_as_serialized_proto(text, nbest_size=10))
sp.nbest_encode_as_serialized_proto(text, nbest_size=10),
)
self.assertEqual(
sp.nbest_encode(text, nbest_size=10, out_type='immutable_proto'),
sp.nbest_encode_as_immutable_proto(text, nbest_size=10))
sp.nbest_encode_as_immutable_proto(text, nbest_size=10),
)
def test_sample_and_score(self):
sp = self.sp_
@@ -608,22 +655,22 @@ class TestSentencepieceProcessor(unittest.TestCase):
text2 = 'I have a pen.'
for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
results = sp.sample_encode_and_score(
text, wor=True, num_samples=10, out_type=out_type)
text, wor=True, num_samples=10, out_type=out_type
)
results = sp.SampleEncodeAndScore(
text, wor=False, num_samples=10, out_type=out_type)
text, wor=False, num_samples=10, out_type=out_type
)
if out_type in [str, int]:
for n in results:
self.assertEqual(sp.decode(n[0]), text)
results = sp.sample_encode_and_score([text, text2],
wor=True,
num_samples=10,
out_type=out_type)
results = sp.SampleEncodeAndScore([text, text2],
wor=True,
num_samples=10,
out_type=out_type)
results = sp.sample_encode_and_score(
[text, text2], wor=True, num_samples=10, out_type=out_type
)
results = sp.SampleEncodeAndScore(
[text, text2], wor=True, num_samples=10, out_type=out_type
)
if out_type in [str, int]:
for n in results[0]:
@@ -639,8 +686,14 @@ class TestSentencepieceProcessor(unittest.TestCase):
def test_valid_range(self):
size = self.sp_.piece_size()
funcs = [
'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused', 'IsByte',
'DecodeIds', 'DecodeIdsAsSerializedProto'
'IdToPiece',
'GetScore',
'IsUnknown',
'IsControl',
'IsUnused',
'IsByte',
'DecodeIds',
'DecodeIdsAsSerializedProto',
]
for m in funcs:
getattr(self.sp_, m)([10, 20, 30])
@@ -654,7 +707,8 @@ class TestSentencepieceProcessor(unittest.TestCase):
def test_batch(self):
sp = spm.SentencePieceProcessor(
model_file=os.path.join('test', 'test_model.model'))
model_file=os.path.join('test', 'test_model.model')
)
with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
texts = file.readlines()
@@ -700,6 +754,12 @@ class TestSentencepieceProcessor(unittest.TestCase):
self.assertEqual(id1, id2)
def test_global_params(self):
spm.SetRandomGeneratorSeed(0)
spm.SetMinLogLevel(2)
spm.set_random_generator_seed(1)
spm.set_min_log_level(3)
def suite():
suite = unittest.TestSuite()

View File

@@ -74,17 +74,9 @@ char (&ArraySizeHelper(const T (&array)[N]))[N];
#endif
namespace sentencepiece {
#ifdef OS_WIN
namespace win32 {
std::wstring Utf8ToWide(const absl::string_view input);
} // namespace win32
#endif
#ifdef IS_BIG_ENDIAN
namespace util {
inline uint32 Swap32(uint32 x) { return __builtin_bswap32(x); }
} // namespace util
#endif
namespace error {

View File

@@ -431,19 +431,19 @@ class SentencePieceProcessor {
#define DEFINE_SPP_DIRECT_FUNC_IMPL(FuncName, OutType, ...) \
OutType output; \
const auto status = FuncName(__VA_ARGS__, &output); \
SPP_SWIG_CHECK_AND_THROW; \
SPP_SWIG_CHECK_AND_THROW; \
return output;
#define DEFINE_SPP_SERIALIZED_PROTO_IMPL(FuncName, OutType, ...) \
OutType output; \
const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
SPP_SWIG_CHECK_AND_THROW; \
SPP_SWIG_CHECK_AND_THROW; \
return output.SerializeAsString();
#define DEFINE_SPP_IMMUTABLE_PROTO_IMPL(FuncName, OutType, ...) \
OutType output; \
const auto status = FuncName(__VA_ARGS__, output.mutable_proto()); \
SPP_SWIG_CHECK_AND_THROW; \
SPP_SWIG_CHECK_AND_THROW; \
return output;
//////////////////////////////////////////////////////////////
@@ -709,6 +709,10 @@ class SentencePieceProcessor {
// std::random_device.
void SetRandomGeneratorSeed(unsigned int seed);
// Set the global log level. The default log level is 0.
// A log message is emitted only when its level (output_log_level) >= min_log_level.
void SetMinLogLevel(int v);
// IO related functions to absorb model formats.
namespace io {
// Loads `model_proto` from `filename`.
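A sketch of using the gate around a noisy call, per the new comment above: raising min_log_level suppresses lower-level messages, and 0 restores the default (the input path and sizes are illustrative only):

import sentencepiece as spm

spm.set_min_log_level(2)  # silence low-severity training chatter (assumed INFO=0, WARNING=1)
spm.SentencePieceTrainer.train(
    input='botchan.txt',  # hypothetical input file
    model_prefix='m',
    vocab_size=1000,
)
spm.set_min_log_level(0)  # back to the documented default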

View File

@@ -43,6 +43,8 @@ int GetMinLogLevel() { return g_minloglevel.load(); }
void SetMinLogLevel(int v) { g_minloglevel.store(v); }
} // namespace logging
void SetMinLogLevel(int v) { logging::SetMinLogLevel(v); }
namespace string_util {
// mblen stores the number of bytes consumed after decoding.
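Since g_minloglevel is read via load() and written via store() (an atomic, per the surrounding code), calling set_min_log_level from one Python thread while another thread is training or encoding is assumed to be free of data races, though messages already past the gate may still be printed.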