mirror of
https://github.com/google/sentencepiece.git
synced 2025-01-08 18:26:38 +03:00
Get rid of dependency from tf_sentencepiece to sentencepiece
This commit is contained in:
parent
e4f5ed7d00
commit
de92f6ace0
@ -3,44 +3,51 @@
|
||||
|
||||
import itertools as it
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import tensorflow as tf
|
||||
import sentencepiece as spm
|
||||
import tf_sentencepiece as tfspm
|
||||
|
||||
class SentencePieceProcssorOpTest(unittest.TestCase):
|
||||
|
||||
def _getSentencePieceModelFile(self):
|
||||
return '../python/test/test_ja_model.model'
|
||||
return os.path.join('..', 'python', 'test', 'test_model.model')
|
||||
|
||||
def _getExpected(self, processor, reverse=False, add_bos=False,
|
||||
def _getPieceSize(self):
|
||||
return 1000
|
||||
|
||||
def _getExpected(self, reverse=False, add_bos=False,
|
||||
add_eos=False, padding=''):
|
||||
options = []
|
||||
# TF uses str(bytes) as a string representation.
|
||||
padding = padding.encode('utf8')
|
||||
sentences = [b'Hello world.', b'I have a pen.',
|
||||
b'I saw a girl with a telescope.']
|
||||
pieces = [[b'\xe2\x96\x81He', b'll', b'o', b'\xe2\x96\x81world', b'.'],
|
||||
[b'\xe2\x96\x81I', b'\xe2\x96\x81have', b'\xe2\x96\x81a',
|
||||
b'\xe2\x96\x81p', b'en', b'.'],
|
||||
[b'\xe2\x96\x81I', b'\xe2\x96\x81saw', b'\xe2\x96\x81a',
|
||||
b'\xe2\x96\x81girl', b'\xe2\x96\x81with',
|
||||
b'\xe2\x96\x81a', b'\xe2\x96\x81',
|
||||
b'te', b'le', b's', b'c', b'o', b'pe', b'.']]
|
||||
ids = [[151, 88, 21, 887, 6],
|
||||
[9, 76, 11, 68, 98, 6],
|
||||
[9, 459, 11, 939, 44, 11, 4, 142, 82, 8, 28, 21, 132, 6]]
|
||||
seq_len = [5, 6, 14]
|
||||
|
||||
if reverse:
|
||||
options.append('reverse')
|
||||
ids = [x[::-1] for x in ids]
|
||||
pieces = [x[::-1] for x in pieces]
|
||||
|
||||
if add_bos:
|
||||
options.append('bos')
|
||||
ids = [[1] + x for x in ids]
|
||||
pieces = [[b'<s>'] + x for x in pieces]
|
||||
seq_len = [x + 1 for x in seq_len]
|
||||
|
||||
if add_eos:
|
||||
options.append('eos')
|
||||
ids = [x + [2] for x in ids]
|
||||
pieces = [x + [b'</s>'] for x in pieces]
|
||||
seq_len = [x + 1 for x in seq_len]
|
||||
|
||||
processor.SetEncodeExtraOptions(':'.join(options))
|
||||
processor.SetDecodeExtraOptions(':'.join(options))
|
||||
|
||||
sentences = ['Hello world.', 'I have a pen.',
|
||||
'I saw a girl with a telescope.']
|
||||
pieces = []
|
||||
ids = []
|
||||
seq_len = []
|
||||
|
||||
for s in sentences:
|
||||
x = processor.EncodeAsPieces(s)
|
||||
y = processor.EncodeAsIds(s)
|
||||
pieces.append(x)
|
||||
ids.append(y)
|
||||
seq_len.append(len(x))
|
||||
self.assertEqual(len(x), len(y))
|
||||
|
||||
# padding
|
||||
max_len = max(seq_len)
|
||||
pieces = [x + [padding] * (max_len - len(x)) for x in pieces]
|
||||
ids = [x + [0] * (max_len - len(x)) for x in ids]
|
||||
@ -49,21 +56,16 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
|
||||
|
||||
def testGetPieceSize(self):
|
||||
sentencepiece_model_file = self._getSentencePieceModelFile()
|
||||
processor = spm.SentencePieceProcessor()
|
||||
processor.Load(sentencepiece_model_file)
|
||||
|
||||
with tf.Session():
|
||||
s = tfspm.piece_size(
|
||||
model_file=sentencepiece_model_file)
|
||||
self.assertEqual(s.eval(), processor.GetPieceSize())
|
||||
self.assertEqual(s.eval(), self._getPieceSize())
|
||||
|
||||
def testConvertPiece(self):
|
||||
sentencepiece_model_file = self._getSentencePieceModelFile()
|
||||
processor = spm.SentencePieceProcessor()
|
||||
processor.Load(sentencepiece_model_file)
|
||||
(sentences, expected_pieces,
|
||||
expected_ids, expected_seq_len) = self._getExpected(processor,
|
||||
padding='<unk>')
|
||||
expected_ids, expected_seq_len) = self._getExpected(padding='<unk>')
|
||||
|
||||
with tf.Session():
|
||||
ids_matrix = tfspm.piece_to_id(
|
||||
@ -97,15 +99,13 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
|
||||
|
||||
def testEncodeAndDecode(self):
|
||||
sentencepiece_model_file = self._getSentencePieceModelFile()
|
||||
processor = spm.SentencePieceProcessor()
|
||||
processor.Load(sentencepiece_model_file)
|
||||
|
||||
with tf.Session():
|
||||
for reverse, add_bos, add_eos in list(it.product(
|
||||
(True, False), repeat=3)):
|
||||
(sentences, expected_pieces,
|
||||
expected_ids, expected_seq_len) = self._getExpected(
|
||||
processor, reverse, add_bos, add_eos)
|
||||
reverse=reverse, add_bos=add_bos, add_eos=add_eos)
|
||||
|
||||
# Encode sentences into pieces/ids.
|
||||
s = tf.constant(sentences)
|
||||
@ -138,9 +138,7 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
|
||||
|
||||
def testSampleEncodeAndDecode(self):
|
||||
sentencepiece_model_file = self._getSentencePieceModelFile()
|
||||
processor = spm.SentencePieceProcessor()
|
||||
processor.Load(sentencepiece_model_file)
|
||||
sentences, _, _, _ = self._getExpected(processor)
|
||||
sentences, _, _, _ = self._getExpected()
|
||||
|
||||
with tf.Session():
|
||||
for n, a in [(-1, 0.1), (64, 0.1), (0, 0.0)]:
|
||||
@ -165,14 +163,12 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
|
||||
|
||||
def testEncodeAndDecodeSparse(self):
|
||||
sentencepiece_model_file = self._getSentencePieceModelFile()
|
||||
processor = spm.SentencePieceProcessor()
|
||||
processor.Load(sentencepiece_model_file)
|
||||
|
||||
with tf.Session():
|
||||
for reverse, add_bos, add_eos in list(it.product(
|
||||
(True, False), repeat=3)):
|
||||
(sentences, expected_pieces, expected_ids,
|
||||
_) = self._getExpected(processor, reverse, add_bos, add_eos)
|
||||
_) = self._getExpected(reverse, add_bos, add_eos)
|
||||
|
||||
# Encode sentences into sparse pieces/ids.
|
||||
s = tf.constant(sentences)
|
||||
@ -191,18 +187,16 @@ class SentencePieceProcssorOpTest(unittest.TestCase):
|
||||
|
||||
def testGetPieceType(self):
|
||||
sentencepiece_model_file = self._getSentencePieceModelFile()
|
||||
processor = spm.SentencePieceProcessor()
|
||||
processor.Load(sentencepiece_model_file)
|
||||
expected_is_unknown = []
|
||||
expected_is_control = []
|
||||
expected_is_unused = []
|
||||
ids = []
|
||||
|
||||
for i in range(processor.GetPieceSize()):
|
||||
for i in range(self._getPieceSize()):
|
||||
ids.append(i)
|
||||
expected_is_unknown.append(processor.IsUnknown(i))
|
||||
expected_is_control.append(processor.IsControl(i))
|
||||
expected_is_unused.append(processor.IsUnused(i))
|
||||
expected_is_unknown.append(i == 0)
|
||||
expected_is_control.append(i == 1 or i == 2)
|
||||
expected_is_unused.append(False)
|
||||
|
||||
with tf.Session():
|
||||
s = tf.constant(ids)
|
||||
|
Loading…
Reference in New Issue
Block a user