mirror of
https://github.com/rsennrich/subword-nmt.git
synced 2024-11-27 02:53:55 +03:00
82 lines
2.6 KiB
Python
Executable File
82 lines
2.6 KiB
Python
Executable File
#! /usr/bin/env python
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import sys
|
|
import inspect
|
|
import warnings
|
|
import argparse
|
|
import codecs
|
|
|
|
from collections import Counter
|
|
|
|
# hack for python2/3 compatibility
|
|
from io import open
|
|
argparse.open = open
|
|
|
|
def create_parser(subparsers=None):
|
|
|
|
if subparsers:
|
|
parser = subparsers.add_parser('get-vocab',
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
description="Generates vocabulary")
|
|
else:
|
|
parser = argparse.ArgumentParser(
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
description="Generates vocabulary")
|
|
|
|
parser.add_argument(
|
|
'--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
|
|
metavar='PATH',
|
|
help="Input file (default: standard input).")
|
|
|
|
parser.add_argument(
|
|
'--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
|
|
metavar='PATH',
|
|
help="Output file (default: standard output)")
|
|
|
|
return parser
|
|
|
|
def get_vocab(train_file, vocab_file):
|
|
|
|
c = Counter()
|
|
|
|
for line in train_file:
|
|
for word in line.strip('\r\n ').split(' '):
|
|
if word:
|
|
c[word] += 1
|
|
|
|
for key,f in sorted(c.items(), key=lambda x: x[1], reverse=True):
|
|
vocab_file.write(key+" "+ str(f) + "\n")
|
|
|
|
if __name__ == "__main__":
|
|
|
|
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
|
newdir = os.path.join(currentdir, 'subword_nmt')
|
|
if os.path.isdir(newdir):
|
|
warnings.simplefilter('default')
|
|
warnings.warn(
|
|
"this script's location has moved to {0}. This symbolic link will be removed in a future version. Please point to the new location, or install the package and use the command 'subword-nmt'".format(newdir),
|
|
DeprecationWarning
|
|
)
|
|
|
|
# python 2/3 compatibility
|
|
if sys.version_info < (3, 0):
|
|
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
|
|
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
|
|
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
|
|
else:
|
|
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
|
|
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
|
|
sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)
|
|
|
|
parser = create_parser()
|
|
args = parser.parse_args()
|
|
|
|
# read/write files as UTF-8
|
|
if args.input.name != '<stdin>':
|
|
args.input = codecs.open(args.input.name, encoding='utf-8')
|
|
if args.output.name != '<stdout>':
|
|
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
|
|
|
get_vocab(args.input, args.output) |