Character token embeddings for word-level predictions

Alexei Baevski 2018-07-28 01:50:50 -07:00 committed by Myle Ott
parent 616afddd89
commit 885e7ec9ec
5 changed files with 253 additions and 2 deletions

View File

@@ -15,7 +15,8 @@ from fairseq import options
from fairseq import utils
from fairseq.modules import (
AdaptiveSoftmax, LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding
AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding
)
from . import (
@@ -161,6 +162,15 @@ class TransformerLanguageModel(FairseqLanguageModel):
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--character-embeddings', default=False, action='store_true',
help='if set, uses character embedding convolutions to produce token embeddings')
parser.add_argument('--character-filters', type=str, metavar='LIST',
default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]',
help='list of (kernel width, number of filters) pairs for the character convolutions')
parser.add_argument('--character-embedding-dim', type=int, metavar='N', default=4,
help='size of character embeddings')
parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2,
help='number of highway layers for character token embedder')
@classmethod
def build_model(cls, args, task):
@@ -174,7 +184,19 @@ class TransformerLanguageModel(FairseqLanguageModel):
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = args.tokens_per_sample
embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
if args.character_embeddings:
if not hasattr(args, 'char_embedder_highway_layers'):
args.char_embedder_highway_layers = 0
if not hasattr(args, 'character_filters'):
args.character_filters = '[(1, 4), (2, 8), (3, 16), (4, 32), (5, 64)]'
embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters),
args.character_embedding_dim,
args.decoder_embed_dim,
args.char_embedder_highway_layers,
)
else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
return TransformerLanguageModel(decoder)
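
For reference, the --character-filters string added above is parsed with eval in build_model into a list of (kernel width, output channels) pairs; the max-pooled outputs of the character convolutions are concatenated, so their total width is the sum of the channel counts, which CharacterTokenEmbedder then projects down to decoder_embed_dim. A minimal sketch of that arithmetic (not part of the commit):

    # Parse the default filter spec the same way build_model does.
    filters = eval('[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]')
    # One Conv1d per (kernel width, out channels) pair; the max-pooled outputs are concatenated.
    final_dim = sum(out_c for _, out_c in filters)
    print(final_dim)  # 1408, projected down to --decoder-embed-dim by the embedder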

View File

@@ -7,9 +7,11 @@
from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
from .character_token_embedder import CharacterTokenEmbedder
from .conv_tbc import ConvTBC
from .downsampled_multihead_attention import DownsampledMultiHeadAttention
from .grad_multiply import GradMultiply
from .highway import Highway
from .learned_positional_embedding import LearnedPositionalEmbedding
from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
@@ -19,9 +21,11 @@ from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
__all__ = [
'AdaptiveSoftmax',
'BeamableMM',
'CharacterTokenEmbedder',
'ConvTBC',
'DownsampledMultiHeadAttention',
'GradMultiply',
'Highway',
'LearnedPositionalEmbedding',
'LinearizedConvolution',
'MultiheadAttention',

View File

@@ -0,0 +1,126 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from typing import List, Tuple
from .highway import Highway
from fairseq.data import Dictionary
class CharacterTokenEmbedder(torch.nn.Module):
def __init__(
self,
vocab: Dictionary,
filters: List[Tuple[int, int]],
char_embed_dim: int,
word_embed_dim: int,
highway_layers: int,
max_char_len: int = 50,
):
super(CharacterTokenEmbedder, self).__init__()
self.embedding_dim = word_embed_dim
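# 257 rows: index 0 is reserved for padding, indices 1-256 hold the 256 possible byte values (shifted by +1 in set_vocab)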
self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0)
self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim))
self.eos_idx, self.unk_idx = 0, 1
self.convolutions = nn.ModuleList()
for width, out_c in filters:
self.convolutions.append(
nn.Conv1d(char_embed_dim, out_c, kernel_size=width)
)
final_dim = sum(f[1] for f in filters)
self.highway = Highway(final_dim, highway_layers)
self.projection = nn.Linear(final_dim, word_embed_dim)
self.set_vocab(vocab, max_char_len)
self.reset_parameters()
def set_vocab(self, vocab, max_char_len):
word_to_char = torch.LongTensor(len(vocab), max_char_len)
truncated = 0
for i in range(len(vocab)):
if i < vocab.nspecial:
char_idxs = [0] * max_char_len
else:
chars = vocab[i].encode()
# +1 for padding
char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars))
if len(char_idxs) > max_char_len:
truncated += 1
char_idxs = char_idxs[:max_char_len]
word_to_char[i] = torch.LongTensor(char_idxs)
if truncated > 0:
print('Truncated {} words longer than {} characters'.format(truncated, max_char_len))
self.vocab = vocab
self.word_to_char = word_to_char
@property
def padding_idx(self):
return self.vocab.pad()
def reset_parameters(self):
nn.init.xavier_normal_(self.char_embeddings.weight)
nn.init.xavier_normal_(self.symbol_embeddings)
nn.init.xavier_normal_(self.projection.weight)
nn.init.constant_(self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.)
nn.init.constant_(self.projection.bias, 0.)
def forward(
self,
words: torch.Tensor,
):
self.word_to_char = self.word_to_char.type_as(words)
flat_words = words.view(-1)
word_embs = self._convolve(self.word_to_char[flat_words])
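# special symbols have all-zero character rows (see set_vocab): zero out pad positions and substitute the learned eos/unk embeddings below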
pads = flat_words.eq(self.vocab.pad())
if pads.any():
word_embs[pads] = 0
eos = flat_words.eq(self.vocab.eos())
if eos.any():
word_embs[eos] = self.symbol_embeddings[self.eos_idx]
unk = flat_words.eq(self.vocab.unk())
if unk.any():
word_embs[unk] = self.symbol_embeddings[self.unk_idx]
return word_embs.view(words.size() + (-1,))
def _convolve(
self,
char_idxs: torch.Tensor,
):
char_embs = self.char_embeddings(char_idxs)
char_embs = char_embs.transpose(1, 2) # BTC -> BCT
conv_result = []
for i, conv in enumerate(self.convolutions):
x = conv(char_embs)
x, _ = torch.max(x, -1)
x = F.relu(x)
conv_result.append(x)
conv_result = torch.cat(conv_result, dim=-1)
conv_result = self.highway(conv_result)
return self.projection(conv_result)
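
A small usage sketch (not part of the commit; the toy vocabulary, filter sizes, and dimensions are arbitrary) showing how the embedder maps a (batch, seq_len) tensor of word indices to (batch, seq_len, word_embed_dim) embeddings:

    import torch
    from fairseq.data import Dictionary
    from fairseq.modules import CharacterTokenEmbedder

    vocab = Dictionary()
    vocab.add_symbol('hello')
    vocab.add_symbol('world')

    # filters are (kernel width, out channels); char_embed_dim=8, word_embed_dim=16, 1 highway layer
    embedder = CharacterTokenEmbedder(vocab, [(1, 8), (2, 8)], 8, 16, 1)

    words = torch.LongTensor([[vocab.eos(), vocab.index('hello'), vocab.index('world'), vocab.eos()]])
    out = embedder(words)  # char lookup -> convolutions -> max-pool -> highway -> projection
    print(out.size())      # torch.Size([1, 4, 16]) == words.size() + (word_embed_dim,)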

View File

@@ -0,0 +1,55 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn.functional as F
from torch import nn
class Highway(torch.nn.Module):
"""
A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
Adapted from the AllenNLP implementation.
"""
def __init__(
self,
input_dim: int,
num_layers: int = 1
):
super(Highway, self).__init__()
self.input_dim = input_dim
self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2)
for _ in range(num_layers)])
self.activation = nn.ReLU()
self.reset_parameters()
def reset_parameters(self):
for layer in self.layers:
# As per comment in AllenNLP:
# We should bias the highway layer to just carry its input forward. We do that by
# setting the bias on `B(x)` to be positive, because that means `g` will be biased to
# be high, so we will carry the input forward. The bias on `B(x)` is the second half
# of the bias vector in each Linear layer.
nn.init.constant_(layer.bias[self.input_dim:], 1)
nn.init.constant_(layer.bias[:self.input_dim], 0)
nn.init.xavier_normal_(layer.weight)
def forward(
self,
x: torch.Tensor
):
for layer in self.layers:
projection = layer(x)
proj_x, gate = projection.chunk(2, dim=-1)
proj_x = self.activation(proj_x)
gate = F.sigmoid(gate)
x = gate * x + (1 - gate) * proj_x
return x
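
A minimal sketch (not part of the commit) spelling out what one highway layer computes: each Linear produces 2 * input_dim features that chunk() splits into a transform half and a gate half, giving y = g * x + (1 - g) * relu(W_h x + b_h) with g = sigmoid(W_g x + b_g), so the positive gate bias set in reset_parameters biases the layer toward carrying its input forward:

    import torch
    import torch.nn.functional as F
    from fairseq.modules import Highway

    hw = Highway(input_dim=8, num_layers=1)
    x = torch.randn(3, 8)

    # Replicate the single layer by hand and compare against the module's output.
    proj, gate = hw.layers[0](x).chunk(2, dim=-1)
    g = torch.sigmoid(gate)
    y_manual = g * x + (1 - g) * F.relu(proj)
    assert torch.allclose(hw(x), y_manual)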

View File

@@ -0,0 +1,44 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import unittest
from fairseq.data import Dictionary
from fairseq.modules import CharacterTokenEmbedder
class TestCharacterTokenEmbedder(unittest.TestCase):
def test_character_token_embedder(self):
vocab = Dictionary()
vocab.add_symbol('hello')
vocab.add_symbol('there')
embedder = CharacterTokenEmbedder(vocab, [(2, 16), (4, 32), (8, 64), (16, 2)], 64, 5, 2)  # filters, char_embed_dim=64, word_embed_dim=5, highway_layers=2
test_sents = [['hello', 'unk', 'there'], ['there'], ['hello', 'there']]
max_len = max(len(s) for s in test_sents)
input = torch.LongTensor(len(test_sents), max_len + 2)
for i in range(len(test_sents)):
input[i][0] = vocab.eos()
for j in range(len(test_sents[i])):
input[i][j + 1] = vocab.index(test_sents[i][j])
input[i][j + 2] = vocab.eos()
embs = embedder(input)
assert embs.size() == (len(test_sents), max_len + 2, 5)
assert embs[0][0].equal(embs[1][0])
assert embs[0][0].equal(embs[0][-1])
assert embs[0][1].equal(embs[2][1])
assert embs[0][3].equal(embs[1][1])
embs.sum().backward()
assert embedder.char_embeddings.weight.grad is not None
if __name__ == '__main__':
unittest.main()