
237 lines
7.4 KiB
Raw Normal View History

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
import torch
# Conv lm implementation
#
# This implements convolutional language model from https://arxiv.org/pdf/1612.08083.pdf
#
# There are 3 modes for constructing batches:
#
# - token block: fill each sample with a specified number of tokens without regard for sentence delimiters - this is what was used for training in the paper
# - complete: fill each sample with a specified number of tokens but make sure it contains only complete sentences (i.e. if next sentence goes over token block limit, move it to the next sample) - this was used for evaluation in the paper
# - eos: one sentence per sample (skip blank lines)
#
# some results:
#
# GCNN-13 - GBW - 37.46
# GCNN-14B - GBW - 33.88
# GCNN-8 - Wiki103 - 43.76
# GCNN-14 - Wiki103 - 35.66
#
# train:
#
# python train.py /private/home/abaevski/data/wiki103 --save-dir /tmp --fp16 --max-epoch 35 --save-interval 1 --save-interval-updates 1000 --keep-interval-updates 25 --arch fconv_lm --optimizer nag --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 --decoder-embed-dim 280 --decoder-layers '[(850, 6)] * 3 + [(850,1)] + [(850,5)] * 4 + [(850,1)] + [(850,4)] * 3 + [(1024,4)] + [(2048, 4)]' --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion cross_entropy --max-tokens 1024 --max-target-positions 1024 --seed 1 --log-format json --log-interval 500
#
# eval:
#
# python eval_lm.py ~abaevski/data/wiki103 --path '/checkpoint02/abaevski/2018-04-27/lm_wiki.fp16.mxup300000.fconv.adam.lrs=reduce_lr_on_plateau.emb280.layers(850,6)*3+(850,1)+(850,5)*4+(850,1)+(850,4)*3+(1024,1)+(2048,4).lr0.0005.clp0.1.drp0.3.wd0.0.crt=cross_entropy.mxtk2048.smptk256.seed1.ngpu8/checkpoint_last.pt'
#
# 2018-05-25 16:43:37 +03:00
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
)
from fairseq.tasks import FairseqTask
def dummy_dictionary(vocab_size, prefix='token_'):
    """Build a ``Dictionary`` holding ``vocab_size`` synthetic tokens.

    Args:
        vocab_size: number of synthetic tokens to create.
        prefix: string prepended to each token's index to form its name
            (``token_0``, ``token_1``, ... by default).

    Returns:
        The finalized ``Dictionary``.
    """
    d = Dictionary()
    for i in range(vocab_size):
        token = prefix + str(i)
        # Register the token; without this the generated name is discarded
        # and the dictionary never receives the synthetic vocabulary.
        d.add_symbol(token)
    d.finalize(padding_factor=1)  # don't add extra padding symbols
    return d
def dummy_dataloader(
    samples,
    padding_idx=1,
    eos_idx=2,
    batch_size=None,
):
    """Wrap a list of sample dicts in a ``DataLoader`` and return its iterator.

    NOTE(review): the parameter list was reconstructed from the names the body
    uses; the ``padding_idx``/``eos_idx`` defaults are assumed -- confirm
    against callers.

    Args:
        samples: sequence of dict-like samples. Samples lacking an ``'id'``
            key are modified in place to receive their position as id.
        padding_idx: padding index forwarded to ``collate``.
        eos_idx: end-of-sentence index forwarded to ``collate``.
        batch_size: batch size; ``None`` puts all samples in a single batch.

    Returns:
        An iterator over the collated batches.
    """
    if batch_size is None:
        batch_size = len(samples)

    # add any missing data to samples
    for i, sample in enumerate(samples):
        if 'id' not in sample:
            sample['id'] = i

    # create dataloader
    dataset = TestDataset(samples)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
    )
    return iter(dataloader)
def sequence_generator_setup():
    """Build the fixtures used by sequence-generator tests.

    Creates a two-token dummy dictionary, a batch of two identical source
    sentences, and a test model whose decoder emits the hand-written
    probabilities stored in ``args.beam_probs``.

    ``args.beam_probs`` holds one tensor per decoding step. Each tensor has
    one row per (sentence, beam) pair and one column per vocabulary entry
    starting at eos: ``eos, unk, w1, w2``.

    Returns:
        Tuple ``(tgt_dict, w1, w2, src_tokens, src_lengths, model)``.
    """
    # construct dummy dictionary
    d = dummy_dictionary(vocab_size=2)

    eos = d.eos()
    # presumably the indices of the two dummy tokens, which follow the
    # dictionary's special symbols -- confirm against Dictionary
    w1 = 4
    w2 = 5

    # construct source data
    src_tokens = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
    src_lengths = torch.LongTensor([2, 2])

    args = argparse.Namespace()
    unk = 0.
    args.beam_probs = [
        # step 0:
        torch.FloatTensor([
            # eos      w1   w2
            # sentence 1:
            [0.0, unk, 0.9, 0.1],  # beam 1
            [0.0, unk, 0.9, 0.1],  # beam 2
            # sentence 2:
            [0.0, unk, 0.7, 0.3],
            [0.0, unk, 0.7, 0.3],
        ]),
        # step 1:
        torch.FloatTensor([
            # eos      w1    w2       prefix
            # sentence 1:
            [1.0, unk, 0.0, 0.0],  # w1: 0.9  (emit: w1 <eos>: 0.9*1.0)
            [0.0, unk, 0.9, 0.1],  # w2: 0.1
            # sentence 2:
            [0.25, unk, 0.35, 0.4],  # w1: 0.7  (don't emit: w1 <eos>: 0.7*0.25)
            [0.00, unk, 0.10, 0.9],  # w2: 0.3
        ]),
        # step 2:
        torch.FloatTensor([
            # eos      w1    w2       prefix
            # sentence 1:
            [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
            [0.6, unk, 0.2, 0.2],  # w2 w2: 0.1*0.1  (emit: w2 w2 <eos>: 0.1*0.1*0.6)
            # sentence 2:
            [0.60, unk, 0.4, 0.00],  # w1 w2: 0.7*0.4  (emit: w1 w2 <eos>: 0.7*0.4*0.6)
            [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
        ]),
        # step 3:
        torch.FloatTensor([
            # eos      w1    w2       prefix
            # sentence 1:
            [1.0, unk, 0.0, 0.0],  # w2 w1 w2: 0.1*0.9*0.9  (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
            [1.0, unk, 0.0, 0.0],  # w2 w1 w1: 0.1*0.9*0.1  (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
            # sentence 2:
            [0.1, unk, 0.5, 0.4],  # w2 w2 w2: 0.3*0.9*0.99  (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
            [1.0, unk, 0.0, 0.0],  # w1 w2 w1: 0.7*0.4*0.4  (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
        ]),
    ]

    task = TestTranslationTask.setup_task(args, d, d)
    model = task.build_model(args)
    tgt_dict = task.target_dictionary

    return tgt_dict, w1, w2, src_tokens, src_lengths, model
class TestDataset(torch.utils.data.Dataset):
    """Minimal in-memory dataset backed by an indexable container."""

    def __init__(self, data):
        """Store ``data``; items are returned from it unchanged."""
        super().__init__()
        self.data = data
        # No per-item size information is tracked.
        self.sizes = None

    def __getitem__(self, index):
        """Return the item stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)
class TestTranslationTask(FairseqTask):
    """Task stub that carries a source/target dictionary pair for tests."""

    def __init__(self, args, src_dict, tgt_dict, model):
        """Store the dictionaries and the optional model on the task."""
        # Initialize the base task with args (FairseqTask's constructor).
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.model = model

    @classmethod
    def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
        """Alternate constructor; invoked on the class itself by callers."""
        return cls(args, src_dict, tgt_dict, model)

    def build_model(self, args):
        """Build a ``TestModel`` bound to this task's dictionaries."""
        return TestModel.build_model(args, self)

    @property
    def source_dictionary(self):
        """The source-side dictionary (read by callers as an attribute)."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """The target-side dictionary (read by callers as an attribute)."""
        return self.tgt_dict
class TestModel(FairseqEncoderDecoderModel):
    """Encoder-decoder model pairing the test encoder and test decoder."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, task):
        """Construct the model from ``args`` and the task's dictionaries.

        Called on the class itself, hence a classmethod.
        """
        encoder = TestEncoder(args, task.source_dictionary)
        decoder = TestIncrementalDecoder(args, task.target_dictionary)
        return cls(encoder, decoder)
class TestEncoder(FairseqEncoder):
    """Pass-through encoder: its output is the source token tensor itself."""

    def __init__(self, args, dictionary):
        # The base encoder must be initialized with the dictionary before
        # attributes are attached to the module.
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """Return ``src_tokens`` unchanged; other arguments are ignored."""
        return src_tokens

    def reorder_encoder_out(self, encoder_out, new_order):
        """Select rows of ``encoder_out`` along dim 0 according to ``new_order``."""
        return encoder_out.index_select(0, new_order)
class TestIncrementalDecoder(FairseqIncrementalDecoder):
    """Decoder that replays scripted probabilities instead of computing them.

    The probabilities come from ``args.probs`` (a 3-D tensor indexed by step
    along dim 1) when present, otherwise from ``args.beam_probs`` (a sequence
    indexed by decoding step).
    """

    def __init__(self, args, dictionary):
        # The base class stores the dictionary; forward() reads
        # ``self.dictionary``.
        super().__init__(dictionary)
        assert hasattr(args, 'beam_probs') or hasattr(args, 'probs')
        args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
        self.args = args

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        """Return ``(probs, attn)`` for the requested decoding steps.

        With ``incremental_state`` only the last previous token is used and a
        single step is produced; the step counter is cached in the state.
        Without it, one step is produced per target position.
        """
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if incremental_state is not None:
            # cache step number
            step = utils.get_incremental_state(self, incremental_state, 'step')
            if step is None:
                step = 0
            utils.set_incremental_state(self, incremental_state, 'step', step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        if hasattr(self.args, 'probs'):
            assert self.args.probs.dim() == 3, \
                'expected probs to have size bsz*steps*vocab'
            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
        else:
            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
            for i, step in enumerate(steps):
                # args.beam_probs gives the probability for every vocab element,
                # starting with eos, then unknown, and then the rest of the vocab
                if step < len(self.args.beam_probs):
                    probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
                else:
                    # past the scripted steps: put all probability on eos
                    probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, tgt_len, src_len)

        dev = prev_output_tokens.device
        return probs.to(dev), attn.to(dev)

    def get_normalized_probs(self, net_output, log_probs, _):
        """Return the decoder output as probabilities or log-probabilities."""
        # the decoder returns probabilities directly
        probs = net_output[0]
        if log_probs:
            return probs.log()
        return probs

    def max_positions(self):
        """Return the maximum decoder positions configured on ``args``."""
        return self.args.max_decoder_positions