mirror of
https://github.com/rsennrich/subword-nmt.git
synced 2024-11-27 12:42:07 +03:00
Allow passing in a word-count file instead of iterating through the whole dataset
This commit is contained in:
parent
4c54e1df2e
commit
f37902dec6
15
learn_bpe.py
15
learn_bpe.py
@ -56,16 +56,21 @@ def create_parser():
|
||||
parser.add_argument(
|
||||
'--verbose', '-v', action="store_true",
|
||||
help="verbose mode.")
|
||||
|
||||
parser.add_argument('--is_dict', '-is_dict', type=bool, default=False,
|
||||
help="Specify this argument if the input file is a dictionary where each line contains a word count pair")
|
||||
return parser
|
||||
|
||||
def get_vocabulary(fobj):
|
||||
def get_vocabulary(fobj, is_dict=False):
|
||||
"""Read text and return dictionary that encodes vocabulary
|
||||
"""
|
||||
vocab = Counter()
|
||||
for line in fobj:
|
||||
for word in line.split():
|
||||
vocab[word] += 1
|
||||
if is_dict:
|
||||
word_count = line.strip().split()
|
||||
vocab[word_count[0]] = int(word_count[1])
|
||||
else:
|
||||
for word in line.split():
|
||||
vocab[word] += 1
|
||||
return vocab
|
||||
|
||||
def update_pair_statistics(pair, changed, stats, indices):
|
||||
@ -189,7 +194,7 @@ if __name__ == '__main__':
|
||||
if args.output.name != '<stdout>':
|
||||
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
||||
|
||||
vocab = get_vocabulary(args.input)
|
||||
vocab = get_vocabulary(args.input, is_dict = args.is_dict)
|
||||
vocab = dict([(tuple(x)+('</w>',) ,y) for (x,y) in vocab.items()])
|
||||
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user