From f296824f4013dc28b471c47d7779547460afc7f0 Mon Sep 17 00:00:00 2001
From: Vladimir Karpukhin
Date: Thu, 28 Feb 2019 09:15:35 -0800
Subject: [PATCH] Move string line encoding logic from tokenizer to Dictionary (unified diff). (#541)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/541

Just a combo of the stacked pair D14057943 & D14176011. Made this as a separate diff because there seems to be some issue with porting a stacked change into the github repo.

Differential Revision: D14251048

fbshipit-source-id: 0a47f534a69d6ab2ebe035fba40fd51748cccfb8
---
 docs/tutorial_classifying_names.rst |  10 +--
 fairseq/binarizer.py                |  67 +++++++++++++++
 fairseq/data/data_utils.py          |   8 ++
 fairseq/data/dictionary.py          | 100 ++++++++++++++++++----
 fairseq/data/indexed_dataset.py     |   6 +-
 fairseq/tasks/fairseq_task.py       |   3 +-
 fairseq/tokenizer.py                | 126 ----------------------------
 fairseq/utils.py                    |   2 +-
 generate.py                         |   3 +-
 interactive.py                      |   2 +-
 preprocess.py                       |  62 +++++++-------
 score.py                            |   6 +-
 tests/test_dictionary.py            |   5 +-
 13 files changed, 204 insertions(+), 196 deletions(-)
 create mode 100644 fairseq/binarizer.py

diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst
index ebdb736a9..cdec9b0a1 100644
--- a/docs/tutorial_classifying_names.rst
+++ b/docs/tutorial_classifying_names.rst
@@ -209,7 +209,6 @@ following contents::

     from fairseq.data import Dictionary, LanguagePairDataset
     from fairseq.tasks import FairseqTask, register_task
-    from fairseq.tokenizer import Tokenizer


 @register_task('simple_classification')
@@ -253,8 +252,8 @@ following contents::
                 sentence = line.strip()

                 # Tokenize the sentence, splitting on spaces
-                tokens = Tokenizer.tokenize(
-                    sentence, self.input_vocab, add_if_not_exist=False,
+                tokens = self.input_vocab.encode_line(
+                    sentence, add_if_not_exist=False,
                 )

                 sentences.append(tokens)
@@ -356,7 +355,6 @@ Finally we can write a short script to evaluate our model on new inputs. Create
 a new file named :file:`eval_classifier.py` with the following contents::

     from fairseq import data, options, tasks, utils
-    from fairseq.tokenizer import Tokenizer

     # Parse command-line arguments for generation
     parser = options.get_generation_parser(default_task='simple_classification')
@@ -375,8 +373,8 @@ a new file named :file:`eval_classifier.py` with the following contents::

     # Tokenize into characters
     chars = ' '.join(list(sentence.strip()))
-    tokens = Tokenizer.tokenize(
-        chars, task.source_dictionary, add_if_not_exist=False,
+    tokens = task.source_dictionary.encode_line(
+        chars, add_if_not_exist=False,
     )

     # Build mini-batch to feed to the model
diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py
new file mode 100644
index 000000000..5130c4e12
--- /dev/null
+++ b/fairseq/binarizer.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+from collections import Counter
+import os
+
+from fairseq.tokenizer import tokenize_line
+
+
+def safe_readline(f):
+    pos = f.tell()
+    while True:
+        try:
+            return f.readline()
+        except UnicodeDecodeError:
+            pos -= 1
+            f.seek(pos)  # search where this character begins
+
+
+class Binarizer:
+
+    @staticmethod
+    def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False,
+                 offset=0, end=-1):
+        nseq, ntok = 0, 0
+        replaced = Counter()
+
+        def replaced_consumer(word, idx):
+            if idx == dict.unk_index and word != dict.unk_word:
+                replaced.update([word])
+
+        with open(filename, 'r', encoding='utf-8') as f:
+            f.seek(offset)
+            # next(f) breaks f.tell(), hence readline() must be used
+            line = safe_readline(f)
+            while line:
+                if end > 0 and f.tell() > end:
+                    break
+                ids = dict.encode_line(
+                    line=line,
+                    line_tokenizer=tokenize,
+                    add_if_not_exist=False,
+                    consumer=replaced_consumer,
+                    append_eos=append_eos,
+                    reverse_order=reverse_order,
+                )
+                nseq += 1
+                ntok += len(ids)
+                consumer(ids)
+                line = f.readline()
+        return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
+
+    @staticmethod
+    def find_offsets(filename, num_chunks):
+        with open(filename, 'r', encoding='utf-8') as f:
+            size = os.fstat(f.fileno()).st_size
+            chunk_size = size // num_chunks
+            offsets = [0 for _ in range(num_chunks + 1)]
+            for i in range(1, num_chunks):
+                f.seek(chunk_size * i)
+                safe_readline(f)
+                offsets[i] = f.tell()
+            return offsets
diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py
index 1adc952ac..85a081a5d 100644
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -213,3 +213,11 @@ def batch_by_size(

     if len(batch) > 0:
         yield batch
+
+
+def process_bpe_symbol(sentence: str, bpe_symbol: str):
+    if bpe_symbol == 'sentencepiece':
+        sentence = sentence.replace('\u2581', ' ').strip()
+    elif bpe_symbol is not None:
+        sentence = (sentence + ' ').replace(bpe_symbol, '').rstrip()
+    return sentence
diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py
index 265be5a84..90be5d3c3 100644
--- a/fairseq/data/dictionary.py
+++ b/fairseq/data/dictionary.py
@@ -6,10 +6,15 @@
 # can be found in the PATENTS file in the same directory.

 from collections import Counter
+from multiprocessing import Pool
 import os

 import torch

+from fairseq.tokenizer import tokenize_line
+from fairseq.binarizer import safe_readline
+from fairseq.data import data_utils
+

 class Dictionary(object):
     """A mapping from symbols to consecutive integers"""
@@ -57,14 +62,8 @@ class Dictionary(object):
         else:
             return self[i]

-        if bpe_symbol == 'sentencepiece':
-            sent = ''.join(token_string(i) for i in tensor if i != self.eos())
-            sent = sent.replace('\u2581', ' ').strip()
-        else:
-            sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
-        if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
-            sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
-        return sent
+        sent = ''.join(token_string(i) for i in tensor if i != self.eos())
+        return data_utils.process_bpe_symbol(sent, bpe_symbol)

     def unk_string(self, escape=False):
         """Return unknown string, optionally escaped as: <<unk>>"""
@@ -181,31 +180,104 @@ class Dictionary(object):
                                 "rebuild the dataset".format(f))

         d = cls()
-        for line in f.readlines():
+        lines = f.readlines()
+        indices_start_line = d._load_meta(lines)
+        for line in lines[indices_start_line:]:
             idx = line.rfind(' ')
             if idx == -1:
                 raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
             word = line[:idx]
-            count = int(line[idx+1:])
+            count = int(line[idx + 1:])
             d.indices[word] = len(d.symbols)
             d.symbols.append(word)
             d.count.append(count)
         return d

-    def save(self, f):
-        """Stores dictionary into a text file"""
+    def _save(self, f, kv_iterator):
         if isinstance(f, str):
             os.makedirs(os.path.dirname(f), exist_ok=True)
             with open(f, 'w', encoding='utf-8') as fd:
                 return self.save(fd)
-        for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
-            print('{} {}'.format(symbol, count), file=f)
+        for k, v in kv_iterator:
+            print('{} {}'.format(k, v), file=f)
+
+    def _get_meta(self):
+        return [], []
+
+    def _load_meta(self, lines):
+        return 0
+
+    def save(self, f):
+        """Stores dictionary into a text file"""
+        ex_keys, ex_vals = self._get_meta()
+        self._save(f, zip(ex_keys + self.symbols[self.nspecial:], ex_vals + self.count[self.nspecial:]))

     def dummy_sentence(self, length):
         t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
         t[-1] = self.eos()
         return t

+    def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
+                    consumer=None, append_eos=True, reverse_order=False):
+        words = line_tokenizer(line)
+        if reverse_order:
+            words = list(reversed(words))
+        nwords = len(words)
+        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
+
+        for i, word in enumerate(words):
+            if add_if_not_exist:
+                idx = self.add_symbol(word)
+            else:
+                idx = self.index(word)
+            if consumer is not None:
+                consumer(word, idx)
+            ids[i] = idx
+        if append_eos:
+            ids[nwords] = self.eos_index
+        return ids
+
+    @staticmethod
+    def _add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
+        counter = Counter()
+        with open(filename, 'r', encoding='utf-8') as f:
+            size = os.fstat(f.fileno()).st_size
+            chunk_size = size // num_workers
+            offset = worker_id * chunk_size
+            end = offset + chunk_size
+            f.seek(offset)
+            if offset > 0:
+                safe_readline(f)  # drop first incomplete line
+            line = f.readline()
+            while line:
+                for word in tokenize(line):
+                    counter.update([word])
+                counter.update([eos_word])
+                if f.tell() > end:
+                    break
+                line = f.readline()
+        return counter
+
+    @staticmethod
+    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
+        def merge_result(counter):
+            for w, c in counter.items():
+                dict.add_symbol(w, c)
+
+        if num_workers > 1:
+            pool = Pool(processes=num_workers)
+            results = []
+            for worker_id in range(num_workers):
+                results.append(pool.apply_async(
+                    Dictionary._add_file_to_dictionary_single_worker,
+                    (filename, tokenize, dict.eos_word, worker_id, num_workers)
+                ))
+            pool.close()
+            pool.join()
+            for r in results:
+                merge_result(r.get())
+        else:
+            merge_result(Dictionary._add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))


 class TruncatedDictionary(object):
diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py
index 970c33ed4..1c7fcfbf9 100644
--- a/fairseq/data/indexed_dataset.py
+++ b/fairseq/data/indexed_dataset.py
@@ -11,8 +11,6 @@ import struct
 import numpy as np
 import torch

-from fairseq.tokenizer import Tokenizer
-

 def read_longs(f, n):
     a = np.empty(n, dtype=np.int64)
@@ -171,8 +169,8 @@ class IndexedRawTextDataset(torch.utils.data.Dataset):
         with open(path, 'r', encoding='utf-8') as f:
             for line in f:
                 self.lines.append(line.strip('\n'))
-                tokens = Tokenizer.tokenize(
-                    line, dictionary, add_if_not_exist=False,
+                tokens = dictionary.encode_line(
+                    line, add_if_not_exist=False,
                     append_eos=self.append_eos, reverse_order=self.reverse_order,
                 ).long()
                 self.tokens_list.append(tokens)
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index af2779dfa..3aefff2d9 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -9,7 +9,6 @@ import torch

 from fairseq import tokenizer
 from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
-from fairseq.tokenizer import Tokenizer


 class FairseqTask(object):
@@ -52,7 +51,7 @@ class FairseqTask(object):
         """
         d = Dictionary()
         for filename in filenames:
-            Tokenizer.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
+            Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
         d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
         return d
diff --git a/fairseq/tokenizer.py b/fairseq/tokenizer.py
index d6525479e..ca368db89 100644
--- a/fairseq/tokenizer.py
+++ b/fairseq/tokenizer.py
@@ -5,13 +5,8 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from collections import Counter
-from multiprocessing import Pool
-import os
 import re

-import torch
-
 SPACE_NORMALIZER = re.compile(r"\s+")

@@ -19,124 +14,3 @@ def tokenize_line(line):
     line = SPACE_NORMALIZER.sub(" ", line)
     line = line.strip()
     return line.split()
-
-
-def safe_readline(f):
-    pos = f.tell()
-    while True:
-        try:
-            return f.readline()
-        except UnicodeDecodeError:
-            pos -= 1
-            f.seek(pos)  # search where this character begins
-
-
-class Tokenizer:
-
-    @staticmethod
-    def add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
-        counter = Counter()
-        with open(filename, 'r', encoding='utf-8') as f:
-            size = os.fstat(f.fileno()).st_size
-            chunk_size = size // num_workers
-            offset = worker_id * chunk_size
-            end = offset + chunk_size
-            f.seek(offset)
-            if offset > 0:
-                safe_readline(f)  # drop first incomplete line
-            line = f.readline()
-            while line:
-                for word in tokenize(line):
-                    counter.update([word])
-                counter.update([eos_word])
-                if f.tell() > end:
-                    break
-                line = f.readline()
-        return counter
-
-    @staticmethod
-    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
-        def merge_result(counter):
-            for w, c in counter.items():
-                dict.add_symbol(w, c)
-
-        if num_workers > 1:
-            pool = Pool(processes=num_workers)
-            results = []
-            for worker_id in range(num_workers):
-                results.append(pool.apply_async(
-                    Tokenizer.add_file_to_dictionary_single_worker,
-                    (filename, tokenize, dict.eos_word, worker_id, num_workers)
-                ))
-            pool.close()
-            pool.join()
-            for r in results:
-                merge_result(r.get())
-        else:
-            merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
-
-    @staticmethod
-    def binarize(
-        filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
-        reverse_order=False, offset=0, end=-1,
-    ):
-        nseq, ntok = 0, 0
-        replaced = Counter()
-
-        def replaced_consumer(word, idx):
-            if idx == dict.unk_index and word != dict.unk_word:
-                replaced.update([word])
-
-        with open(filename, 'r', encoding='utf-8') as f:
-            f.seek(offset)
-            # next(f) breaks f.tell(), hence readline() must be used
-            line = safe_readline(f)
-            while line:
-                if end > 0 and f.tell() > end:
-                    break
-                ids = Tokenizer.tokenize(
-                    line=line,
-                    dict=dict,
-                    tokenize=tokenize,
-                    add_if_not_exist=False,
-                    consumer=replaced_consumer,
-                    append_eos=append_eos,
-                    reverse_order=reverse_order,
-                )
-                nseq += 1
-                ntok += len(ids)
-                consumer(ids)
-                line = f.readline()
-        return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}
-
-    @staticmethod
-    def find_offsets(filename, num_chunks):
-        with open(filename, 'r', encoding='utf-8') as f:
-            size = os.fstat(f.fileno()).st_size
-            chunk_size = size // num_chunks
-            offsets = [0 for _ in range(num_chunks + 1)]
-            for i in range(1, num_chunks):
-                f.seek(chunk_size * i)
-                safe_readline(f)
-                offsets[i] = f.tell()
-            return offsets
-
-    @staticmethod
-    def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
-                 consumer=None, append_eos=True, reverse_order=False):
-        words = tokenize(line)
-        if reverse_order:
-            words = list(reversed(words))
-        nwords = len(words)
-        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
-
-        for i, word in enumerate(words):
-            if add_if_not_exist:
-                idx = dict.add_symbol(word)
-            else:
-                idx = dict.index(word)
-            if consumer is not None:
-                consumer(word, idx)
-            ids[i] = idx
-        if append_eos:
-            ids[nwords] = dict.eos_index
-        return ids
diff --git a/fairseq/utils.py b/fairseq/utils.py
index 34e5d5be4..044ed1a2e 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -304,7 +304,7 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dic
     if align_dict is not None or remove_bpe is not None:
         # Convert back to tokens for evaluating with unk replacement or without BPE
         # Note that the dictionary can be modified inside the method.
-        hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
+        hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)

     return hypo_tokens, hypo_str, alignment
diff --git a/generate.py b/generate.py
index b6467637f..343957709 100644
--- a/generate.py
+++ b/generate.py
@@ -165,8 +165,7 @@ def main(args):
                 if has_target and i == 0:
                     if align_dict is not None or args.remove_bpe is not None:
                         # Convert back to tokens for evaluation with unk replacement and/or without BPE
-                        target_tokens = tokenizer.Tokenizer.tokenize(
-                            target_str, tgt_dict, add_if_not_exist=True)
+                        target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                     if hasattr(scorer, 'add_string'):
                         scorer.add_string(target_str, hypo_str)
                     else:
diff --git a/interactive.py b/interactive.py
index 8408e122f..41423bdb7 100644
--- a/interactive.py
+++ b/interactive.py
@@ -38,7 +38,7 @@ def buffered_read(input, buffer_size):

 def make_batches(lines, args, task, max_positions):
     tokens = [
-        tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
+        task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
         for src_str in lines
     ]
     lengths = torch.LongTensor([t.numel() for t in tokens])
diff --git a/preprocess.py b/preprocess.py
index 4fb8a821a..aacadef92 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -11,15 +11,15 @@ Data pre-processing: build vocabularies and binarize training data.

 from collections import Counter
 from itertools import zip_longest
-import os
-import shutil

 from fairseq import options, tasks
 from fairseq.data import indexed_dataset
-from fairseq.tokenizer import Tokenizer
+from fairseq.binarizer import Binarizer
+from fairseq.utils import import_user_module
 from multiprocessing import Pool

-from fairseq.utils import import_user_module
+import os
+import shutil


 def main(args):
@@ -95,9 +95,8 @@ def main(args):
     if target and tgt_dict is not None:
         tgt_dict.save(dict_path(args.target_lang))

-    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
-        dict = task.load_dictionary(dict_path(lang))
-        print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1))
+    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
+        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))

         n_seq_tok = [0, 0]
         replaced = Counter()
@@ -109,7 +108,7 @@ def main(args):
         input_file = "{}{}".format(
             input_prefix, ("." + lang) if lang is not None else ""
         )
-        offsets = Tokenizer.find_offsets(input_file, num_workers)
+        offsets = Binarizer.find_offsets(input_file, num_workers)
         pool = None
         if num_workers > 1:
             pool = Pool(processes=num_workers - 1)
@@ -120,13 +119,13 @@ def main(args):
                     (
                         args,
                         input_file,
-                        dict,
+                        vocab,
                         prefix,
                         lang,
                         offsets[worker_id],
-                        offsets[worker_id + 1],
+                        offsets[worker_id + 1]
                     ),
-                    callback=merge_result,
+                    callback=merge_result
                 )
             pool.close()
@@ -134,8 +133,9 @@ def main(args):
             dataset_dest_file(args, output_prefix, lang, "bin")
         )
         merge_result(
-            Tokenizer.binarize(
-                input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1]
+            Binarizer.binarize(
+                input_file, vocab, lambda t: ds.add_item(t),
+                offset=0, end=offsets[1]
             )
         )
         if num_workers > 1:
@@ -156,13 +156,13 @@ def main(args):
                 n_seq_tok[0],
                 n_seq_tok[1],
                 100 * sum(replaced.values()) / n_seq_tok[1],
-                dict.unk_word,
+                vocab.unk_word,
             )
         )

-    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
+    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
         if args.output_format == "binary":
-            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
+            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
         elif args.output_format == "raw":
             # Copy original text file to destination folder
             output_text_file = dest_path(
@@ -171,21 +171,21 @@ def main(args):
         )
         shutil.copyfile(file_name(input_prefix, lang), output_text_file)

-    def make_all(lang):
+    def make_all(lang, vocab):
         if args.trainpref:
-            make_dataset(args.trainpref, "train", lang, num_workers=args.workers)
+            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
         if args.validpref:
             for k, validpref in enumerate(args.validpref.split(",")):
                 outprefix = "valid{}".format(k) if k > 0 else "valid"
-                make_dataset(validpref, outprefix, lang)
+                make_dataset(vocab, validpref, outprefix, lang)
         if args.testpref:
             for k, testpref in enumerate(args.testpref.split(",")):
                 outprefix = "test{}".format(k) if k > 0 else "test"
-                make_dataset(testpref, outprefix, lang)
+                make_dataset(vocab, testpref, outprefix, lang)

-    make_all(args.source_lang)
+    make_all(args.source_lang, src_dict)
     if target:
-        make_all(args.target_lang)
+        make_all(args.target_lang, tgt_dict)

     print("| Wrote preprocessed data to {}".format(args.destdir))
@@ -198,8 +198,8 @@ def main(args):
         with open(src_file_name, "r", encoding='utf-8') as src_file:
             with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                 for a, s, t in zip_longest(align_file, src_file, tgt_file):
-                    si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
-                    ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
+                    si = src_dict.encode_line(s, add_if_not_exist=False)
+                    ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                     ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                     for sai, tai in ai:
                         srcidx = si[int(sai)]
@@ -232,7 +232,7 @@ def main(args):
             print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)


-def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos=True):
+def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
     ds = indexed_dataset.IndexedDatasetBuilder(
         dataset_dest_file(args, output_prefix, lang, "bin")
     )
@@ -240,14 +240,8 @@ def binarize(args, filename, dict, output_prefix, lang, offset, end, append_eos=
     def consumer(tensor):
         ds.add_item(tensor)

-    res = Tokenizer.binarize(
-        filename,
-        dict,
-        consumer,
-        offset=offset,
-        end=end,
-        append_eos=append_eos
-    )
+    res = Binarizer.binarize(filename, vocab, consumer, append_eos=append_eos,
+                             offset=offset, end=end)
     ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
     return res
@@ -266,7 +260,7 @@ def dataset_dest_file(args, output_prefix, lang, extension):

 def get_offsets(input_file, num_workers):
-    return Tokenizer.find_offsets(input_file, num_workers)
+    return Binarizer.find_offsets(input_file, num_workers)


 def merge_files(files, outpath):
diff --git a/score.py b/score.py
index 184b431ff..e4c098666 100644
--- a/score.py
+++ b/score.py
@@ -13,7 +13,7 @@ import argparse
 import os
 import sys

-from fairseq import bleu, tokenizer
+from fairseq import bleu
 from fairseq.data import dictionary


@@ -62,8 +62,8 @@ def main():
     with open(args.ref) as fdref:
         scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
         for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
-            sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
-            ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
+            sys_tok = dict.encode_line(sys_tok)
+            ref_tok = dict.encode_line(ref_tok)
             scorer.add(ref_tok, sys_tok)
     print(scorer.result_string(args.order))
diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py
index 572f78a33..4df94b0ce 100644
--- a/tests/test_dictionary.py
+++ b/tests/test_dictionary.py
@@ -11,7 +11,6 @@ import unittest
 import torch

 from fairseq.data import Dictionary
-from fairseq.tokenizer import Tokenizer


 class TestDictionary(unittest.TestCase):
@@ -39,12 +38,12 @@ class TestDictionary(unittest.TestCase):
         # build dictionary
         d = Dictionary()
         for line in txt:
-            Tokenizer.tokenize(line, d, add_if_not_exist=True)
+            d.encode_line(line, add_if_not_exist=True)

         def get_ids(dictionary):
             ids = []
             for line in txt:
-                ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False))
+                ids.append(dictionary.encode_line(line, add_if_not_exist=False))
             return ids

         def assertMatch(ids, ref_ids):