#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich
""" Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text.
This script learns BPE jointly on a concatenation of a list of texts ( typically the source and target side of a parallel corpus ,
applies the learned operation to each and ( optionally ) returns the resulting vocabulary of each text .
The vocabulary can be used in apply_bpe . py to avoid producing symbols that are rare or OOV in a training text .
Reference :
Rico Sennrich , Barry Haddow and Alexandra Birch ( 2016 ) . Neural Machine Translation of Rare Words with Subword Units .
Proceedings of the 54 th Annual Meeting of the Association for Computational Linguistics ( ACL 2016 ) . Berlin , Germany .
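
Example usage (illustrative; the corpus and output file names are placeholders; with the
package installed, the same functionality is exposed as `subword-nmt learn-joint-bpe-and-vocab`):

    python learn_joint_bpe_and_vocab.py -s 10000 -o codes.bpe --input corpus.src corpus.trg --write-vocabulary vocab.src vocab.trg

The vocabulary files written by --write-vocabulary can then be passed to apply_bpe.py
(via its --vocabulary and --vocabulary-threshold options) when segmenting new text.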
"""
from __future__ import unicode_literals
import sys
import os
import inspect
import codecs
import argparse
import tempfile
import warnings
from collections import Counter
from multiprocessing import cpu_count
#hack to get imports working if running this as a script, or within a package
if __name__ == '__main__':
    import learn_bpe
    import apply_bpe
else:
    from . import learn_bpe
    from . import apply_bpe

# hack for python2/3 compatibility
from io import open
argparse.open = open

def create_parser(subparsers=None):
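    """Build the argument parser: standalone, or as the 'learn-joint-bpe-and-vocab'
    subcommand when an argparse subparsers object is passed in."""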
    if subparsers:
        parser = subparsers.add_parser('learn-joint-bpe-and-vocab',
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="learn BPE-based word segmentation")
    else:
        parser = argparse.ArgumentParser(
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="learn BPE-based word segmentation")

    parser.add_argument(
        '--input', '-i', type=argparse.FileType('r'), required=True, nargs='+',
        metavar='PATH',
        help="Input texts (multiple allowed).")
    parser.add_argument(
        '--output', '-o', type=argparse.FileType('w'), required=True,
        metavar='PATH',
        help="Output file for BPE codes.")
    parser.add_argument(
        '--symbols', '-s', type=int, default=10000,
        help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
    parser.add_argument(
        '--separator', type=str, default='@@', metavar='STR',
        help="Separator between non-final subword units (default: '%(default)s')")
    parser.add_argument(
        '--write-vocabulary', type=argparse.FileType('w'), required=True, nargs='+', default=None,
        metavar='PATH', dest='vocab',
        help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
    parser.add_argument(
        '--min-frequency', type=int, default=2, metavar='FREQ',
        help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
    parser.add_argument(
        '--total-symbols', '-t', action="store_true",
        help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
    parser.add_argument(
        '--num-workers', type=int, default=1,
        help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
    parser.add_argument(
        '--verbose', '-v', action="store_true",
        help="verbose mode.")

    return parser

def learn_joint_bpe_and_vocab(args):
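    """Learn a joint BPE model on all inputs and write one subword vocabulary per input.

    Mirrors the steps below: (1) build a combined word-frequency vocabulary over all
    input texts, (2) learn BPE merge operations on it and write them to the output
    file, and (3) apply the learned codes to each input and write its frequency-sorted
    subword vocabulary to the corresponding --write-vocabulary file.
    """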
    if args.vocab and len(args.input) != len(args.vocab):
        sys.stderr.write('Error: number of input files and vocabulary files must match\n')
        sys.exit(1)

    # read/write files as UTF-8
    args.input = [codecs.open(f.name, encoding='UTF-8') for f in args.input]
    args.vocab = [codecs.open(f.name, 'w', encoding='UTF-8') for f in args.vocab]

    # get combined vocabulary of all input texts
    full_vocab = Counter()
    for f in args.input:
        full_vocab += learn_bpe.get_vocabulary(f, num_workers=args.num_workers)
        f.seek(0)

    vocab_list = ['{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()]

    # learn BPE on combined vocabulary
    with codecs.open(args.output.name, 'w', encoding='UTF-8') as output:
        learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose, is_dict=True, total_symbols=args.total_symbols)

    with codecs.open(args.output.name, encoding='UTF-8') as codes:
        bpe = apply_bpe.BPE(codes, separator=args.separator)

    # apply BPE to each training corpus and get vocabulary
    for train_file, vocab_file in zip(args.input, args.vocab):
        tmp = tempfile.NamedTemporaryFile(delete=False)
        tmp.close()
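        # delete=False so the temporary file can be reopened below with an explicit
        # UTF-8 codec; it is removed manually with os.remove() once its vocabulary
        # has been counted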
        tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')

        train_file.seek(0)
        bpe.process_lines(train_file.name, tmpout, num_workers=args.num_workers)

        tmpout.close()
        tmpin = codecs.open(tmp.name, encoding='UTF-8')

        vocab = learn_bpe.get_vocabulary(tmpin, num_workers=args.num_workers)
        tmpin.close()
        os.remove(tmp.name)

        for key, freq in sorted(vocab.items(), key=lambda x: x[1], reverse=True):
            vocab_file.write("{0} {1}\n".format(key, freq))

        vocab_file.close()

if __name__ == '__main__':
    currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    newdir = os.path.join(currentdir, 'subword_nmt')
    if os.path.isdir(newdir):
        warnings.simplefilter('default')
        warnings.warn(
            "this script's location has moved to {0}. This symbolic link will be removed in a future version. Please point to the new location, or install the package and use the command 'subword-nmt'".format(newdir),
            DeprecationWarning
        )

    # python 2/3 compatibility
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    else:
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)

    parser = create_parser()
    args = parser.parse_args()

    if args.num_workers <= 0:
        args.num_workers = cpu_count()

    if sys.version_info < (3, 0):
        args.separator = args.separator.decode('UTF-8')
        if args.num_workers > 1:
            args.num_workers = 1
            warnings.warn("Parallel mode is only supported in Python3. Using 1 processor instead.")

    assert len(args.input) == len(args.vocab)

    learn_joint_bpe_and_vocab(args)