#!/usr/bin/python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich

"""Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text.
Unlike the original BPE, it does not compress the plain text, but can be used to reduce the vocabulary
of a text to a configurable number of symbols, with only a small increase in the number of tokens.

Reference:
Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units.
"""

from __future__ import unicode_literals

import sys
import codecs
import re
import copy
import argparse
from collections import defaultdict, Counter

# hack for python2/3 compatibility
from io import open
argparse.open = open

# python 2/3 compatibility
if sys.version_info < (3, 0):
    sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
    sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
    sys.stdin = codecs.getreader('UTF-8')(sys.stdin)


def create_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="learn BPE-based word segmentation")

    parser.add_argument(
        '--input', '-i', type=argparse.FileType('r'), default=sys.stdin,
        metavar='PATH',
        help="Input text (default: standard input).")
    parser.add_argument(
        '--output', '-o', type=argparse.FileType('w'), default=sys.stdout,
        metavar='PATH',
        help="Output file for BPE codes (default: standard output)")
    parser.add_argument(
        '--symbols', '-s', type=int, default=10000,
        help="Create this many new symbols (each representing a character n-gram) (default: %(default)s)")

    return parser


def get_vocabulary(fobj):
    """Read text and return dictionary that encodes vocabulary"""
    vocab = Counter()
    for line in fobj:
        for word in line.split():
            vocab[word] += 1
    return vocab


def update_pair_statistics(pair, changed, stats, indices):
    """Minimally update the indices and frequency of symbol pairs

    If we merge a pair of symbols, only pairs that overlap with occurrences
    of this pair are affected, and need to be updated.
    """
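    # For example, merging the pair ('e', 'r') inside the word ('l', 'o', 'w', 'e', 'r', '</w>')
    # only affects the overlapping pairs: the counts of ('w', 'e') and ('r', '</w>') are
    # decremented, and the counts of the new pairs ('w', 'er') and ('er', '</w>') are
    # incremented; all other pair statistics are left untouched.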
""" stats[pair] = 0 indices[pair] = defaultdict(int) first, second = pair new_pair = first+second for j, word, old_word, freq in changed: # find all instances of pair, and update frequency/indices around it i = 0 while True: try: i = old_word.index(first, i) except ValueError: break if i < len(old_word)-1 and old_word[i+1] == second: if i: prev = old_word[i-1:i+1] stats[prev] -= freq indices[prev][j] -= 1 if i < len(old_word)-2: # don't double-count consecutive pairs if old_word[i+2] != first or i >= len(old_word)-3 or old_word[i+3] != second: nex = old_word[i+1:i+3] stats[nex] -= freq indices[nex][j] -= 1 i += 2 else: i += 1 i = 0 while True: try: i = word.index(new_pair, i) except ValueError: break if i: prev = word[i-1:i+1] stats[prev] += freq indices[prev][j] += 1 # don't double-count consecutive pairs if i < len(word)-1 and word[i+1] != new_pair: nex = word[i:i+2] stats[nex] += freq indices[nex][j] += 1 i += 1 def get_pair_statistics(vocab): """Count frequency of all symbol pairs, and create index""" # data structure of pair frequencies stats = defaultdict(int) #index from pairs to words indices = defaultdict(lambda: defaultdict(int)) for i, (word, freq) in enumerate(vocab): prev_char = word[0] for char in word[1:]: stats[prev_char, char] += freq indices[prev_char, char][i] += 1 prev_char = char return stats, indices def replace_pair(pair, vocab, indices): """Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'""" first, second = pair pair_str = ''.join(pair) changes = [] pattern = re.compile(r'(?',) ,y) for (x,y) in vocab.items()]) sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True) stats, indices = get_pair_statistics(sorted_vocab) big_stats = copy.deepcopy(stats) # threshold is inspired by Zipfian assumption, but should only affect speed threshold = max(stats.values()) / 10 for i in range(args.symbols): most_frequent = max(stats, key=stats.get) # we probably missed the best pair because of pruning; go back to full statistics if i and stats[most_frequent] < threshold: prune_stats(stats, big_stats, threshold) stats = copy.deepcopy(big_stats) most_frequent = max(stats, key=stats.get) # threshold is inspired by Zipfian assumption, but should only affect speed threshold = stats[most_frequent] * i/(i+10000.0) prune_stats(stats, big_stats, threshold) if stats[most_frequent] < 2: sys.stderr.write('no pair has frequency > 1. Stopping\n') break sys.stderr.write('pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format(i, most_frequent[0], most_frequent[1], stats[most_frequent])) args.output.write('{0} {1}\n'.format(*most_frequent)) changes = replace_pair(most_frequent, sorted_vocab, indices) update_pair_statistics(most_frequent, changes, stats, indices) stats[most_frequent] = 0 if not i % 100: prune_stats(stats, big_stats, threshold)