diff --git a/learn_bpe.py b/learn_bpe.py index 5e92b2c..034792f 100755 --- a/learn_bpe.py +++ b/learn_bpe.py @@ -43,6 +43,7 @@ def create_parser(): '--input', '-i', type=argparse.FileType('r'), default=sys.stdin, metavar='PATH', help="Input text (default: standard input).") + parser.add_argument( '--output', '-o', type=argparse.FileType('w'), default=sys.stdout, metavar='PATH', @@ -53,11 +54,12 @@ def create_parser(): parser.add_argument( '--min-frequency', type=int, default=2, metavar='FREQ', help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))') + parser.add_argument('--dict-input', action="store_true", + help="If set, input file is interpreted as a dictionary where each line contains a word-count pair") parser.add_argument( '--verbose', '-v', action="store_true", help="verbose mode.") - parser.add_argument('--is_dict', '-is_dict', type=bool, default=False, - help="Specify this argument if the input file is a dictionary where each line contains a word count pair") + return parser def get_vocabulary(fobj, is_dict=False): @@ -66,8 +68,8 @@ def get_vocabulary(fobj, is_dict=False): vocab = Counter() for line in fobj: if is_dict: - word_count = line.strip().split() - vocab[word_count[0]] = int(word_count[1]) + word, count = line.strip().split() + vocab[word] = int(count) else: for word in line.split(): vocab[word] += 1 @@ -194,7 +196,7 @@ if __name__ == '__main__': if args.output.name != '': args.output = codecs.open(args.output.name, 'w', encoding='utf-8') - vocab = get_vocabulary(args.input, is_dict = args.is_dict) + vocab = get_vocabulary(args.input, is_dict = args.dict_input) vocab = dict([(tuple(x)+('',) ,y) for (x,y) in vocab.items()]) sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)