mirror of https://github.com/rsennrich/subword-nmt.git (synced 2024-11-25 21:34:58 +03:00)

initial commit
commit 83b1847647
21  LICENSE  Normal file
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2015 University of Edinburgh

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
34  README.md  Normal file
@@ -0,0 +1,34 @@
Subword Neural Machine Translation
==================================

This repository contains preprocessing scripts to segment text into subword
units. The primary purpose is to facilitate the reproduction of our experiments
on Neural Machine Translation with subword units (see below for reference).

USAGE INSTRUCTIONS
------------------

Check the individual files for usage instructions.

To apply byte pair encoding to word segmentation, invoke these commands:

    ./learn_bpe.py -s {num_operations} < {train_file} > {codes_file}
    ./apply_bpe.py -c {codes_file} < {test_file}

To segment rare words into character n-grams, do the following:

    ./get_vocab.py < {train_file} > {vocab_file}
    ./segment-char-ngrams.py --vocab {vocab_file} -n {order} --shortlist {size} < {test_file}

The original segmentation can be restored with a simple replacement:

    sed "s/@@ //g"

PUBLICATIONS
------------

The segmentation methods are described in:

Rico Sennrich, Barry Haddow and Alexandra Birch (2015):
   Neural Machine Translation of Rare Words with Subword Units
   http://arxiv.org/abs/1508.07909
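The same restoration can be done in Python; this is just a re.sub equivalent of the sed call above, with a made-up segmented sentence for illustration:

    import re

    segmented = "th@@ is is an ex@@ ample sent@@ ence"
    print(re.sub(r'@@ ', '', segmented))   # this is an example sentence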
131  apply_bpe.py  Executable file
@@ -0,0 +1,131 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich

"""Use operations learned with learn_bpe.py to encode a new text.
The text will not be smaller, but use only a fixed vocabulary, with rare words
encoded as variable-length sequences of subword units.

Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units.
"""

from __future__ import unicode_literals, division

import sys
import codecs
import argparse
from collections import defaultdict

# hack for python2/3 compatibility
from io import open
argparse.open = open

# python 2/3 compatibility
if sys.version_info < (3, 0):
    sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
    sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
    sys.stdin = codecs.getreader('UTF-8')(sys.stdin)

def create_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="apply BPE-based word segmentation")

    parser.add_argument(
        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
        metavar='PATH',
        help="Input file (default: standard input).")
    parser.add_argument(
        '--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
        required=True,
        help="File with BPE codes (created by learn_bpe.py).")
    parser.add_argument(
        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
        metavar='PATH',
        help="Output file (default: standard output)")
    parser.add_argument(
        '--separator', '-s', type=str, default='@@', metavar='STR',
        help="Separator between non-final subword units (default: '%(default)s')")

    return parser

def get_pairs(word):
    """Return set of symbol pairs in a word.

    word is represented as tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

def encode(orig, bpe_codes, cache={}):
    """Encode word based on list of BPE merge operations, which are applied consecutively
    """

    if orig in cache:
        return cache[orig]

    word = tuple(orig) + ('</w>',)
    pairs = get_pairs(word)

    while True:
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
        if bigram not in bpe_codes:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>', ''),)

    cache[orig] = word
    return word


if __name__ == '__main__':
    parser = create_parser()
    args = parser.parse_args()

    bpe_codes = [tuple(item.split()) for item in args.codes]

    # some hacking to deal with duplicates (only consider first instance)
    bpe_codes = dict([(code, i) for (i, code) in reversed(list(enumerate(bpe_codes)))])

    for line in args.input:
        for word in line.split():
            new_word = encode(word, bpe_codes)

            for item in new_word[:-1]:
                args.output.write(item + args.separator + ' ')
            args.output.write(new_word[-1] + ' ')

        args.output.write('\n')
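As a sanity check, here is a toy use of the encode() function above. It assumes apply_bpe.py is importable (e.g. the repository root is on the Python path); the merge operations are made up for the example rather than learned from real data:

    from apply_bpe import encode

    # toy merge operations, keyed by priority; real codes come from learn_bpe.py
    bpe_codes = {('l', 'o'): 0, ('w', '</w>'): 1, ('lo', 'w</w>'): 2}

    print(encode('low', bpe_codes))   # ('low',)    -- all merges apply, '</w>' is stripped
    print(encode('lot', bpe_codes))   # ('lo', 't') -- only the first merge applies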
13  get_vocab.py  Executable file
@@ -0,0 +1,13 @@
#! /usr/bin/env python

import sys
from collections import Counter

c = Counter()

for line in sys.stdin:
    for word in line.split():
        c[word] += 1

# print "word count" pairs, most frequent first (works under both Python 2 and 3)
for key, f in sorted(c.items(), key=lambda x: x[1], reverse=True):
    print('{0} {1}'.format(key, f))
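To make the output format concrete, the same counting logic on a made-up corpus produces one "word count" pair per line, most frequent first; this is the format segment-char-ngrams.py expects in its --vocab file:

    from collections import Counter

    corpus = "the cat sat on the mat the end"   # made-up example text
    counts = Counter(corpus.split())
    for word, freq in sorted(counts.items(), key=lambda x: x[1], reverse=True):
        print('{0} {1}'.format(word, freq))
    # the 3
    # cat 1
    # ...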
200  learn_bpe.py  Executable file
@@ -0,0 +1,200 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich

"""Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text.
Unlike the original BPE, it does not compress the plain text, but can be used to reduce the vocabulary
of a text to a configurable number of symbols, with only a small increase in the number of tokens.

Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units.
"""

from __future__ import unicode_literals

import sys
import codecs
import re
import copy
import argparse
from collections import defaultdict, Counter

# hack for python2/3 compatibility
from io import open
argparse.open = open

# python 2/3 compatibility
if sys.version_info < (3, 0):
    sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
    sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
    sys.stdin = codecs.getreader('UTF-8')(sys.stdin)

def create_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="learn BPE-based word segmentation")

    parser.add_argument(
        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
        metavar='PATH',
        help="Input text (default: standard input).")
    parser.add_argument(
        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
        metavar='PATH',
        help="Output file for BPE codes (default: standard output)")
    parser.add_argument(
        '--symbols', '-s', type=int, default=10000,
        help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")

    return parser

def get_vocabulary(fobj):
    """Read text and return dictionary that encodes vocabulary
    """
    vocab = Counter()
    for line in fobj:
        for word in line.split():
            vocab[word] += 1
    return vocab

def update_pair_statistics(pair, changed, stats, indices):
    """Minimally update the indices and frequency of symbol pairs

    if we merge a pair of symbols, only pairs that overlap with occurrences
    of this pair are affected, and need to be updated.
    """
    stats[pair] = 0
    indices[pair] = defaultdict(int)
    first, second = pair
    new_pair = first+second
    for j, word, old_word, freq in changed:

        # find all instances of pair, and update frequency/indices around it
        i = 0
        while True:
            try:
                i = old_word.index(first, i)
            except ValueError:
                break
            if i < len(old_word)-1 and old_word[i+1] == second:
                if i:
                    prev = old_word[i-1:i+1]
                    stats[prev] -= freq
                    indices[prev][j] -= 1
                if i < len(old_word)-2:
                    # don't double-count consecutive pairs
                    if old_word[i+2] != first or i >= len(old_word)-3 or old_word[i+3] != second:
                        nex = old_word[i+1:i+3]
                        stats[nex] -= freq
                        indices[nex][j] -= 1
                i += 2
            else:
                i += 1

        i = 0
        while True:
            try:
                i = word.index(new_pair, i)
            except ValueError:
                break
            if i:
                prev = word[i-1:i+1]
                stats[prev] += freq
                indices[prev][j] += 1
            # don't double-count consecutive pairs
            if i < len(word)-1 and word[i+1] != new_pair:
                nex = word[i:i+2]
                stats[nex] += freq
                indices[nex][j] += 1
            i += 1


def get_pair_statistics(vocab):
    """Count frequency of all symbol pairs, and create index"""

    # data structure of pair frequencies
    stats = defaultdict(int)

    # index from pairs to words
    indices = defaultdict(lambda: defaultdict(int))

    for i, (word, freq) in enumerate(vocab):
        prev_char = word[0]
        for char in word[1:]:
            stats[prev_char, char] += freq
            indices[prev_char, char][i] += 1
            prev_char = char

    return stats, indices


def replace_pair(pair, vocab, indices):
    """Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'"""
    first, second = pair
    pair_str = ''.join(pair)
    changes = []
    pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)')
    for j, freq in indices[pair].items():
        if freq < 1:
            continue
        word, freq = vocab[j]
        new_word = ' '.join(word)
        new_word = pattern.sub(pair_str, new_word)
        new_word = tuple(new_word.split())

        vocab[j] = (new_word, freq)
        changes.append((j, new_word, word, freq))

    return changes

def prune_stats(stats, big_stats, threshold):
    """Prune statistics dict for efficiency of max()

    The frequency of a symbol pair never increases, so pruning is generally safe
    (until the most frequent pair is less frequent than a pair we previously pruned)
    big_stats keeps full statistics for when we need to access pruned items
    """
    for item, freq in list(stats.items()):
        if freq < threshold:
            del stats[item]
            if freq < 0:
                big_stats[item] += freq
            else:
                big_stats[item] = freq

if __name__ == '__main__':

    parser = create_parser()
    args = parser.parse_args()

    vocab = get_vocabulary(args.input)
    vocab = dict([(tuple(x)+('</w>',), y) for (x, y) in vocab.items()])
    sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)

    stats, indices = get_pair_statistics(sorted_vocab)
    big_stats = copy.deepcopy(stats)
    # threshold is inspired by Zipfian assumption, but should only affect speed
    threshold = max(stats.values()) / 10
    for i in range(args.symbols):
        most_frequent = max(stats, key=stats.get)

        # we probably missed the best pair because of pruning; go back to full statistics
        if i and stats[most_frequent] < threshold:
            prune_stats(stats, big_stats, threshold)
            stats = copy.deepcopy(big_stats)
            most_frequent = max(stats, key=stats.get)
            # threshold is inspired by Zipfian assumption, but should only affect speed
            threshold = stats[most_frequent] * i/(i+10000.0)
            prune_stats(stats, big_stats, threshold)

        if stats[most_frequent] < 2:
            sys.stderr.write('no pair has frequency > 1. Stopping\n')
            break

        sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent]))
        args.output.write('{0} {1}\n'.format(*most_frequent))
        changes = replace_pair(most_frequent, sorted_vocab, indices)
        update_pair_statistics(most_frequent, changes, stats, indices)
        stats[most_frequent] = 0
        if not i % 100:
            prune_stats(stats, big_stats, threshold)
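The learning loop above can be exercised on a toy corpus by importing the functions from learn_bpe.py (assuming the repository root is on the Python path and Python 3); the corpus and the number of merges below are invented for the example:

    import io
    from learn_bpe import (get_vocabulary, get_pair_statistics,
                           replace_pair, update_pair_statistics)

    corpus = io.StringIO("low low low lower lower newest newest newest newest")
    vocab = get_vocabulary(corpus)
    vocab = dict([(tuple(x) + ('</w>',), y) for (x, y) in vocab.items()])
    sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)

    stats, indices = get_pair_statistics(sorted_vocab)
    for i in range(5):   # learn 5 merge operations
        most_frequent = max(stats, key=stats.get)
        if stats[most_frequent] < 2:
            break
        print('{0} {1}'.format(*most_frequent))   # one merge rule per line, as in the codes file
        changes = replace_pair(most_frequent, sorted_vocab, indices)
        update_pair_statistics(most_frequent, changes, stats, indices)
        stats[most_frequent] = 0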
71  segment-char-ngrams.py  Executable file
@@ -0,0 +1,71 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich

from __future__ import unicode_literals, division

import sys
import codecs
import argparse

# hack for python2/3 compatibility
from io import open
argparse.open = open

# python 2/3 compatibility
if sys.version_info < (3, 0):
    sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
    sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
    sys.stdin = codecs.getreader('UTF-8')(sys.stdin)

def create_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="segment rare words into character n-grams")

    parser.add_argument(
        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
        metavar='PATH',
        help="Input file (default: standard input).")
    parser.add_argument(
        '--vocab', type=argparse.FileType('r'), metavar='PATH',
        required=True,
        help="Vocabulary file.")
    parser.add_argument(
        '--shortlist', type=int, metavar='INT', default=0,
        help="do not segment INT most frequent words in vocabulary (default: '%(default)s').")
    parser.add_argument(
        '-n', type=int, metavar='INT', default=2,
        help="segment rare words into character n-grams of size INT (default: '%(default)s').")
    parser.add_argument(
        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
        metavar='PATH',
        help="Output file (default: standard output)")
    parser.add_argument(
        '--separator', '-s', type=str, default='@@', metavar='STR',
        help="Separator between non-final subword units (default: '%(default)s')")

    return parser


if __name__ == '__main__':

    parser = create_parser()
    args = parser.parse_args()

    vocab = [line.split()[0] for line in args.vocab if len(line.split()) == 2]
    vocab = dict((y, x) for (x, y) in enumerate(vocab))

    for line in args.input:
        for word in line.split():
            if word not in vocab or vocab[word] > args.shortlist:
                i = 0
                while i*args.n < len(word):
                    args.output.write(word[i*args.n:i*args.n+args.n])
                    i += 1
                    if i*args.n < len(word):
                        args.output.write(args.separator)
                args.output.write(' ')
            else:
                args.output.write(word + ' ')
        args.output.write('\n')
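The core splitting step above, shown in isolation: a rare word is cut into character n-grams, with the separator between non-final chunks. The helper function, word and parameter values are examples chosen for illustration, not part of the script:

    def segment(word, n=2, separator='@@'):
        ngrams = [word[i:i + n] for i in range(0, len(word), n)]
        return (separator + ' ').join(ngrams)

    print(segment('antidisestablishmentarianism', n=3))
    # ant@@ idi@@ ses@@ tab@@ lis@@ hme@@ nta@@ ria@@ nis@@ m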