2017-12-21 23:38:22 +03:00
#!/usr/bin/env python
2017-02-10 14:06:52 +03:00
# -*- coding: utf-8 -*-
# Author: Rico Sennrich
"""Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text.
This script learns BPE jointly on a concatenation of a list of texts (typically the source and target side of a parallel corpus),
applies the learned operations to each text and (optionally) writes the resulting vocabulary of each text.
The vocabulary can be used in apply_bpe.py to avoid producing symbols that are rare or OOV in a training text.
Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2016). Neural Machine Translation of Rare Words with Subword Units.
Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany.
"""
from __future__ import unicode_literals
import sys
import os
2018-05-16 16:35:47 +03:00
import inspect
2017-02-10 14:06:52 +03:00
import codecs
import argparse
import tempfile
2018-05-16 16:35:47 +03:00
import warnings
2017-02-10 14:06:52 +03:00
from collections import Counter
2020-06-17 20:04:38 +03:00
from multiprocessing import cpu_count
2017-02-10 14:06:52 +03:00
2018-05-21 12:53:59 +03:00
#hack to get imports working if running this as a script, or within a package
if __name__ == ' __main__ ' :
import learn_bpe
import apply_bpe
else :
from . import learn_bpe
from . import apply_bpe
2017-02-10 14:06:52 +03:00
2018-05-16 14:22:01 +03:00
def _utf8_bytes(value):
    """argparse ``type`` converter: encode a command-line string as UTF-8 bytes.

    ``type=bytes`` cannot be used for this: on Python 3 ``bytes('@@')`` raises
    TypeError (no encoding given), so argparse rejected every explicit
    ``--separator`` value.
    """
    return value.encode('UTF-8')


def create_parser(subparsers=None):
    """Build the argument parser for learning joint BPE codes and vocabularies.

    Args:
        subparsers: optional argparse subparsers action. When given, the parser
            is registered as the ``learn-joint-bpe-and-vocab`` sub-command
            instead of being created standalone.

    Returns:
        The configured ``argparse.ArgumentParser``.

    Note:
        ``--separator`` is parsed to ``bytes`` (default ``b'@@'``); the caller
        decodes it when not running in byte-level mode.
    """
    if subparsers:
        parser = subparsers.add_parser(
            'learn-joint-bpe-and-vocab',
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="learn BPE-based word segmentation")
    else:
        parser = argparse.ArgumentParser(
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description="learn BPE-based word segmentation")

    parser.add_argument(
        '--input', '-i', type=argparse.FileType('rb'), required=True, nargs='+',
        metavar='PATH',
        help="Input texts (multiple allowed).")
    parser.add_argument(
        '--output', '-o', type=argparse.FileType('wb'), required=True,
        metavar='PATH',
        help="Output file for BPE codes.")
    parser.add_argument(
        '--symbols', '-s', type=int, default=10000,
        help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")
    parser.add_argument(
        '--byte', '-b', action="store_true",
        help="byte-level BPE.")
    parser.add_argument(
        # A str default would be run through `type`; a bytes default is kept
        # as-is, so both explicit and default values end up as bytes.
        '--separator', type=_utf8_bytes, default=b'@@', metavar='STR',
        help="Separator between non-final subword units (default: '@@')")
    parser.add_argument(
        # `default=None` is unreachable while required=True; kept for compatibility.
        '--write-vocabulary', type=argparse.FileType('wb'), required=True, nargs='+', default=None,
        metavar='PATH', dest='vocab',
        help='Write to these vocabulary files after applying BPE. One per input text. Used for filtering in apply_bpe.py')
    parser.add_argument(
        '--min-frequency', type=int, default=2, metavar='FREQ',
        help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s)')
    parser.add_argument(
        '--total-symbols', '-t', action="store_true",
        help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
    parser.add_argument(
        '--num-workers', type=int, default=1,
        help="Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. (default: %(default)s)")
    parser.add_argument(
        '--verbose', '-v', action="store_true",
        help="verbose mode.")
    return parser
2018-05-16 14:22:01 +03:00
def _path_of(handle):
    """Return the filesystem path behind an argparse-opened *handle*, closing it.

    ``argparse.FileType`` opens every path eagerly, but this module reopens
    each file by name in the mode it actually needs. The original handle only
    carries ``.name``, so it is closed here instead of being leaked. Handles
    named ``'<stdin>'`` / ``'<stdout>'`` (standard streams) are left open.
    """
    name = handle.name
    if name not in ('<stdin>', '<stdout>'):
        handle.close()
    return name


def _open_for_mode(path, mode, is_bytes):
    """Open *path* for reading (``'r'``) or writing (``'w'``).

    In byte mode the file is opened in binary mode. Otherwise it is wrapped in
    a UTF-8 ``codecs`` stream so reads and writes use ``str``.
    """
    if is_bytes:
        return open(path, mode + 'b')
    return codecs.open(path, mode, encoding='UTF-8')


def _vocab_line(key, freq, is_bytes):
    """Format one ``"<symbol> <frequency>"`` entry, without a trailing newline."""
    if is_bytes:
        return key + b' ' + str(freq).encode('UTF-8')
    return '{0} {1}'.format(key, freq)


def _segmented_vocabulary(bpe, train_path, num_workers, is_bytes):
    """Segment the file at *train_path* with *bpe* and count the resulting symbols.

    The segmented text goes through a temporary file, which is removed even if
    segmentation or counting raises.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False)
    tmp.close()
    try:
        with _open_for_mode(tmp.name, 'w', is_bytes) as tmpout:
            bpe.process_lines(train_path, tmpout, num_workers=num_workers)
        with _open_for_mode(tmp.name, 'r', is_bytes) as tmpin:
            return learn_bpe.get_vocabulary(tmpin, num_workers=num_workers, is_bytes=is_bytes)
    finally:
        os.remove(tmp.name)


def learn_joint_bpe_and_vocab(args):
    """Learn BPE codes jointly over all inputs, then write one vocabulary per input.

    ``args`` is the namespace produced by ``create_parser()``. Attributes read:
    ``input``, ``vocab`` and ``output`` (file handles; only their ``.name`` is
    used), ``byte``, ``separator``, ``symbols``, ``min_frequency``,
    ``verbose``, ``total_symbols`` and ``num_workers``. ``args.separator`` is
    rewritten to ``str`` (text mode) or ``bytes`` (byte mode).

    Files are reopened by name, so ``-`` (stdin/stdout) is not supported.
    If ``args.vocab`` is empty or None, the codes are learned and no
    vocabulary is written.

    Exits with status 1 if the numbers of input and vocabulary files differ.
    """
    vocab_handles = args.vocab or []
    if vocab_handles and len(args.input) != len(vocab_handles):
        sys.stderr.write('Error: number of input files and vocabulary files must match\n')
        sys.exit(1)

    is_bytes = args.byte
    input_paths = [_path_of(f) for f in args.input]
    vocab_paths = [_path_of(f) for f in vocab_handles]
    output_path = _path_of(args.output)

    # The separator type must match the I/O mode: bytes for byte-level BPE,
    # str otherwise. Accept either type from programmatic callers.
    separator = args.separator
    if is_bytes and not isinstance(separator, bytes):
        separator = separator.encode('UTF-8')
    elif not is_bytes and isinstance(separator, bytes):
        separator = separator.decode('UTF-8')
    args.separator = separator

    # Combined vocabulary of all input texts.
    full_vocab = Counter()
    for path in input_paths:
        with _open_for_mode(path, 'r', is_bytes) as infile:
            full_vocab += learn_bpe.get_vocabulary(infile, num_workers=args.num_workers, is_bytes=is_bytes)
    vocab_list = [_vocab_line(key, freq, is_bytes) for key, freq in full_vocab.items()]

    # Learn BPE on the combined vocabulary, then load the written codes back.
    with _open_for_mode(output_path, 'w', is_bytes) as output:
        learn_bpe.learn_bpe(vocab_list, output, args.symbols, args.min_frequency, args.verbose,
                            is_dict=True, is_bytes=is_bytes, total_symbols=args.total_symbols)
    with _open_for_mode(output_path, 'r', is_bytes) as codes:
        bpe = apply_bpe.BPE(codes, separator=separator, is_bytes=is_bytes)

    # Apply BPE to each training corpus and write its vocabulary, most frequent first.
    newline = b'\n' if is_bytes else '\n'
    for train_path, vocab_path in zip(input_paths, vocab_paths):
        vocab = _segmented_vocabulary(bpe, train_path, args.num_workers, is_bytes)
        with _open_for_mode(vocab_path, 'w', is_bytes) as vocab_file:
            for key, freq in sorted(vocab.items(), key=lambda item: item[1], reverse=True):
                vocab_file.write(_vocab_line(key, freq, is_bytes) + newline)
2018-05-16 14:22:01 +03:00
if __name__ == '__main__':
    # Warn users who still run this file through the legacy top-level location.
    currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    newdir = os.path.join(currentdir, 'subword_nmt')
    if os.path.isdir(newdir):
        warnings.warn(
            "this script's location has moved to {0}. This symbolic link will be removed in a future version. Please point to the new location, or install the package and use the command 'subword-nmt'".format(newdir),
            DeprecationWarning
        )

    # Python 2 is unsupported; exit before touching sys.*.buffer (Python 3 only).
    # The error goes to stderr, not stdout.
    if sys.version_info < (3, 0):
        sys.stderr.write("Python 2 is deprecated. Use Python 3\n")
        sys.exit(1)

    # Force UTF-8 on the standard streams regardless of the locale.
    sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
    sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
    sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)

    parser = create_parser()
    args = parser.parse_args()

    # A non-positive worker count means "use every available CPU".
    if args.num_workers <= 0:
        args.num_workers = cpu_count()

    # Validate with parser.error rather than assert: asserts are stripped
    # under -O and would otherwise surface as a bare AssertionError traceback.
    if len(args.input) != len(args.vocab):
        parser.error('number of input files and vocabulary files must match')

    learn_joint_bpe_and_vocab(args)