initial commit

Rico Sennrich 2015-09-01 10:45:31 +01:00
commit 83b1847647
6 changed files with 470 additions and 0 deletions

21
LICENSE Normal file

@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2015 University of Edinburgh
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

34
README.md Normal file

@@ -0,0 +1,34 @@
Subword Neural Machine Translation
==================================
This repository contains preprocessing scripts to segment text into subword
units. The primary purpose is to facilitate the reproduction of our experiments
on Neural Machine Translation with subword units (see below for reference).
USAGE INSTRUCTIONS
------------------
Check the individual files for usage instructions.
To apply byte pair encoding to word segmentation, invoke these commands:
./learn_bpe.py -s {num_operations} < {train_file} > {codes_file}
./apply_bpe.py -c {codes_file} < {test_file}
To segment rare words into character n-grams, do the following:
./get_vocab.py < {train_file} > {vocab_file}
./segment-char-ngrams.py --vocab {vocab_file} -n {order} --shortlist {size} < {test_file}
The original segmentation can be restored with a simple replacement:
sed "s/@@ //g"
PUBLICATIONS
------------
The segmentation methods are described in:
Rico Sennrich, Barry Haddow and Alexandra Birch (2015):
Neural Machine Translation of Rare Words with Subword Units
http://arxiv.org/abs/1508.07909

131
apply_bpe.py Executable file

@@ -0,0 +1,131 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich
"""Use operations learned with learn_bpe.py to encode a new text.
The text will not become smaller, but it will use only a fixed vocabulary, with rare
words encoded as variable-length sequences of subword units.
Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units.
"""
from __future__ import unicode_literals, division
import sys
import codecs
import argparse
from collections import defaultdict
# hack for python2/3 compatibility
from io import open
argparse.open = open
# python 2/3 compatibility
if sys.version_info < (3, 0):
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
def create_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description="learn BPE-based word segmentation")
parser.add_argument(
'--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
metavar='PATH',
help="Input file (default: standard input).")
parser.add_argument(
'--codes', '-c', type=argparse.FileType('r'), metavar='PATH',
required=True,
help="File with BPE codes (created by learn_bpe.py).")
parser.add_argument(
'--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
metavar='PATH',
help="Output file (default: standard output)")
parser.add_argument(
'--separator', '-s', type=str, default='@@', metavar='STR',
help="Separator between non-final subword units (default: '%(default)s'))")
return parser
def get_pairs(word):
"""Return set of symbol pairs in a word.
word is represented as tuple of symbols (symbols being variable-length strings)
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def encode(orig, bpe_codes, cache={}):
"""Encode word based on list of BPE merge operations, which are applied consecutively
"""
if orig in cache:
return cache[orig]
word = tuple(orig) + ('</w>',)
pairs = get_pairs(word)
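    # greedily merge the pair that was learned earliest (lowest index in
    # bpe_codes); repeat until none of the word's remaining symbol pairs
    # was learned during training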
while True:
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
if bigram not in bpe_codes:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
            except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
# don't print end-of-word symbols
if word[-1] == '</w>':
word = word[:-1]
elif word[-1].endswith('</w>'):
word = word[:-1] + (word[-1].replace('</w>',''),)
cache[orig] = word
return word
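
# Hypothetical example: with bpe_codes = {('l', 'o'): 0, ('lo', 'w'): 1},
# encode('low', bpe_codes) merges ('l', 'o', 'w', '</w>') -> ('lo', 'w', '</w>')
# -> ('low', '</w>') and returns ('low',) once the end-of-word marker is stripped.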
if __name__ == '__main__':
parser = create_parser()
args = parser.parse_args()
bpe_codes = [tuple(item.split()) for item in args.codes]
# some hacking to deal with duplicates (only consider first instance)
bpe_codes = dict([(code,i) for (i,code) in reversed(list(enumerate(bpe_codes)))])
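    # the dict maps each pair to its merge priority: a lower index means the
    # pair was learned earlier, and encode() always applies the lowest-ranked
    # pair first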
for line in args.input:
for word in line.split():
new_word = encode(word, bpe_codes)
for item in new_word[:-1]:
args.output.write(item + args.separator + ' ')
args.output.write(new_word[-1] + ' ')
args.output.write('\n')

13
get_vocab.py Executable file

@@ -0,0 +1,13 @@
#! /usr/bin/env python
from __future__ import print_function
import sys
from collections import Counter

c = Counter()
for line in sys.stdin:
    for word in line.split():
        c[word] += 1

for key, f in sorted(c.items(), key=lambda x: x[1], reverse=True):
    print(key, f)
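
# output format: one "word count" pair per line, most frequent word first,
# as expected by segment-char-ngrams.py --vocab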

200
learn_bpe.py Executable file

@@ -0,0 +1,200 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich
"""Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text.
Unlike the original BPE, it does not compress the plain text, but can be used to reduce the vocabulary
of a text to a configurable number of symbols, with only a small increase in the number of tokens.
Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units.
"""
from __future__ import unicode_literals
import sys
import codecs
import re
import copy
import argparse
from collections import defaultdict, Counter
# hack for python2/3 compatibility
from io import open
argparse.open = open
# python 2/3 compatibility
if sys.version_info < (3, 0):
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
def create_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description="learn BPE-based word segmentation")
parser.add_argument(
'--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
metavar='PATH',
help="Input text (default: standard input).")
parser.add_argument(
'--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
metavar='PATH',
help="Output file for BPE codes (default: standard output)")
parser.add_argument(
'--symbols', '-s', type=int, default=10000,
help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))")
return parser
def get_vocabulary(fobj):
"""Read text and return dictionary that encodes vocabulary
"""
vocab = Counter()
for line in fobj:
for word in line.split():
vocab[word] += 1
return vocab
def update_pair_statistics(pair, changed, stats, indices):
"""Minimally update the indices and frequency of symbol pairs
if we merge a pair of symbols, only pairs that overlap with occurrences
of this pair are affected, and need to be updated.
"""
stats[pair] = 0
indices[pair] = defaultdict(int)
first, second = pair
new_pair = first+second
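    # two passes over each changed word: first subtract the counts of pairs
    # that overlapped an occurrence of the merged pair in the old word, then
    # add counts for the pairs that the new merged symbol forms in the new word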
for j, word, old_word, freq in changed:
# find all instances of pair, and update frequency/indices around it
i = 0
while True:
try:
i = old_word.index(first, i)
except ValueError:
break
if i < len(old_word)-1 and old_word[i+1] == second:
if i:
prev = old_word[i-1:i+1]
stats[prev] -= freq
indices[prev][j] -= 1
if i < len(old_word)-2:
# don't double-count consecutive pairs
if old_word[i+2] != first or i >= len(old_word)-3 or old_word[i+3] != second:
nex = old_word[i+1:i+3]
stats[nex] -= freq
indices[nex][j] -= 1
i += 2
else:
i += 1
i = 0
while True:
try:
i = word.index(new_pair, i)
except ValueError:
break
if i:
prev = word[i-1:i+1]
stats[prev] += freq
indices[prev][j] += 1
# don't double-count consecutive pairs
if i < len(word)-1 and word[i+1] != new_pair:
nex = word[i:i+2]
stats[nex] += freq
indices[nex][j] += 1
i += 1
def get_pair_statistics(vocab):
"""Count frequency of all symbol pairs, and create index"""
# data structure of pair frequencies
stats = defaultdict(int)
#index from pairs to words
indices = defaultdict(lambda: defaultdict(int))
for i, (word, freq) in enumerate(vocab):
prev_char = word[0]
for char in word[1:]:
stats[prev_char, char] += freq
indices[prev_char, char][i] += 1
prev_char = char
return stats, indices
def replace_pair(pair, vocab, indices):
"""Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'"""
first, second = pair
pair_str = ''.join(pair)
changes = []
pattern = re.compile(r'(?<!\S)' + re.escape(first + ' ' + second) + r'(?!\S)')
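    # the lookarounds demand whitespace (or a string edge) on both sides, so
    # only whole symbols are merged, never substrings of longer symbols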
for j, freq in indices[pair].items():
if freq < 1:
continue
word, freq = vocab[j]
new_word = ' '.join(word)
new_word = pattern.sub(pair_str, new_word)
new_word = tuple(new_word.split())
vocab[j] = (new_word, freq)
changes.append((j, new_word, word, freq))
return changes
def prune_stats(stats, big_stats, threshold):
"""Prune statistics dict for efficiency of max()
The frequency of a symbol pair never increases, so pruning is generally safe
(until the most frequent pair is less frequent than a pair we previously pruned)
big_stats keeps full statistics for when we need to access pruned items
"""
for item,freq in list(stats.items()):
if freq < threshold:
del stats[item]
if freq < 0:
big_stats[item] += freq
else:
big_stats[item] = freq
if __name__ == '__main__':
parser = create_parser()
args = parser.parse_args()
vocab = get_vocabulary(args.input)
    vocab = dict([(tuple(x)+('</w>',), y) for (x, y) in vocab.items()])
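    # each word becomes a tuple of characters plus an end-of-word marker, so
    # merges can never cross word boundaries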
sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True)
stats, indices = get_pair_statistics(sorted_vocab)
big_stats = copy.deepcopy(stats)
# threshold is inspired by Zipfian assumption, but should only affect speed
threshold = max(stats.values()) / 10
for i in range(args.symbols):
most_frequent = max(stats, key=stats.get)
# we probably missed the best pair because of pruning; go back to full statistics
if i and stats[most_frequent] < threshold:
prune_stats(stats, big_stats, threshold)
stats = copy.deepcopy(big_stats)
most_frequent = max(stats, key=stats.get)
# threshold is inspired by Zipfian assumption, but should only affect speed
threshold = stats[most_frequent] * i/(i+10000.0)
prune_stats(stats, big_stats, threshold)
if stats[most_frequent] < 2:
sys.stderr.write('no pair has frequency > 1. Stopping\n')
break
sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent]))
args.output.write('{0} {1}\n'.format(*most_frequent))
changes = replace_pair(most_frequent, sorted_vocab, indices)
update_pair_statistics(most_frequent, changes, stats, indices)
stats[most_frequent] = 0
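        # the merged pair no longer occurs as two separate symbols; keep its
        # count at zero so it cannot be selected again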
if not i % 100:
prune_stats(stats, big_stats, threshold)
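
# Hypothetical trace on the toy corpus "low low lower": the first two merges
# might be ('l', 'o') -> 'lo' and ('lo', 'w') -> 'low', written to the codes
# file as the lines "l o" and "lo w".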

71
segment-char-ngrams.py Executable file

@@ -0,0 +1,71 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich
from __future__ import unicode_literals, division
import sys
import codecs
import argparse
# hack for python2/3 compatibility
from io import open
argparse.open = open
# python 2/3 compatibility
if sys.version_info < (3, 0):
sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
def create_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter,
description="segment rare words into character n-grams")
parser.add_argument(
'--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
metavar='PATH',
help="Input file (default: standard input).")
parser.add_argument(
'--vocab', type=argparse.FileType('r'), metavar='PATH',
required=True,
help="Vocabulary file.")
parser.add_argument(
'--shortlist', type=int, metavar='INT', default=0,
help="do not segment INT most frequent words in vocabulary (default: '%(default)s')).")
parser.add_argument(
'-n', type=int, metavar='INT', default=2,
help="segment rare words into character n-grams of size INT (default: '%(default)s')).")
parser.add_argument(
'--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
metavar='PATH',
help="Output file (default: standard output)")
parser.add_argument(
'--separator', '-s', type=str, default='@@', metavar='STR',
help="Separator between non-final subword units (default: '%(default)s'))")
return parser
if __name__ == '__main__':
parser = create_parser()
args = parser.parse_args()
vocab = [line.split()[0] for line in args.vocab if len(line.split()) == 2]
vocab = dict((y,x) for (x,y) in enumerate(vocab))
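    # map each word to its frequency rank (0 = most frequent); the vocabulary
    # file is assumed to be sorted by frequency, as produced by get_vocab.py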
for line in args.input:
for word in line.split():
if word not in vocab or vocab[word] > args.shortlist:
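                # rare word: emit consecutive character n-grams joined by the
                # separator, e.g. with n=2 (hypothetical): "subword" -> "su@@ bw@@ or@@ d"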
i = 0
while i*args.n < len(word):
args.output.write(word[i*args.n:i*args.n+args.n])
i += 1
if i*args.n < len(word):
args.output.write(args.separator)
args.output.write(' ')
else:
args.output.write(word + ' ')
args.output.write('\n')