Update HF GPT2 (#1058)

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1058

Differential Revision: D20222225

Pulled By: myleott

fbshipit-source-id: 37fd5b66b1dd5518086c156b61dc8e832b6f20d7
Myle Ott 2020-03-03 16:04:20 -08:00 committed by Facebook Github Bot
parent 244835d811
commit 6a4ad3327a
7 changed files with 8 additions and 10 deletions

.gitmodules
@@ -1,3 +1,4 @@
 [submodule "fairseq/models/huggingface/transformers"]
 	path = fairseq/models/huggingface/transformers
-	url = https://github.com/huggingface/transformers.git
+	url = https://github.com/myleott/transformers.git
+	branch = fairseq

@@ -77,7 +77,7 @@ class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder):
         config = GPT2Config(
             vocab_size=len(task.target_dictionary),
-            n_positions=args.max_target_positions,
+            n_positions=args.max_target_positions + 1,
             n_ctx=args.max_target_positions,
             n_embd=args.embed_dim,
             n_layer=args.num_layers,
@@ -138,7 +138,7 @@ class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder):
         return last_hidden_states

     def max_positions(self):
-        return self.model.config.n_positions
+        return self.model.config.n_positions - 1

 @register_model_architecture('hf_gpt2', 'hf_gpt2')
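Read together, the two hunks above reserve one extra position embedding in the HF config and then hide it again in max_positions(). A minimal sketch of that pairing (my reading of the diff; the reason for the off-by-one is an assumption, e.g. room for the special token fairseq prepends to prev_output_tokens, and MAX_TARGET_POSITIONS is a placeholder for args.max_target_positions). It assumes a transformers version that still accepts n_ctx, as the submodule pinned here does:

# Hedged sketch, not the actual fairseq module.
from transformers import GPT2Config

MAX_TARGET_POSITIONS = 1024  # stand-in for args.max_target_positions

config = GPT2Config(
    vocab_size=50257,                      # stand-in for len(task.target_dictionary)
    n_positions=MAX_TARGET_POSITIONS + 1,  # one extra position embedding slot
    n_ctx=MAX_TARGET_POSITIONS,
)

def max_positions():
    # Report the limit fairseq should enforce: subtract the extra slot again.
    return config.n_positions - 1

assert max_positions() == MAX_TARGET_POSITIONS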

@@ -1 +1 @@
-Subproject commit d426b58b9e32a2ffc8c8a1196143270e22054a46
+Subproject commit a8cad0cc839e3cdf980a0edabac6f18b7ee4d9e4

@@ -18,7 +18,6 @@ try:
         def forward(self, x):
             return super().forward(x)
 except ImportError:
     has_fused_layernorm = False
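For reference, the hunk above sits inside the usual optional-dependency guard around apex's fused layer norm. A self-contained sketch of that pattern follows; the LayerNorm factory at the end is illustrative rather than copied from the file:

import torch

try:
    # Optional fused CUDA kernel from NVIDIA apex.
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    has_fused_layernorm = True

    class FusedLayerNorm(_FusedLayerNorm):
        def forward(self, x):
            return super().forward(x)

except ImportError:
    has_fused_layernorm = False


def LayerNorm(normalized_shape, eps=1e-5):
    # Illustrative factory: prefer the fused kernel when apex is installed,
    # otherwise fall back to the stock PyTorch implementation.
    if torch.cuda.is_available() and has_fused_layernorm:
        return FusedLayerNorm(normalized_shape, eps=eps)
    return torch.nn.LayerNorm(normalized_shape, eps=eps)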

@@ -572,7 +572,7 @@ class EnsembleModel(torch.nn.Module):
             decoder_out[0] = decoder_out[0][:, -1:, :]
         if temperature != 1.:
             decoder_out[0].div_(temperature)
-        attn = decoder_out[1]
+        attn = decoder_out[1] if len(decoder_out) > 1 else None
         if type(attn) is dict:
             attn = attn.get('attn', None)
         if type(attn) is list:
@@ -689,7 +689,7 @@ class EnsembleModelWithAlignment(EnsembleModel):
             decoder_out[0] = decoder_out[0][:, -1:, :]
         if temperature != 1.:
             decoder_out[0].div_(temperature)
-        attn = decoder_out[1]
+        attn = decoder_out[1] if len(decoder_out) > 1 else None
         if type(attn) is dict:
             attn = attn.get('attn', None)
         if type(attn) is list:
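The same guard is added in three call sites. A minimal sketch of the pattern, under my assumption (suggested by the commit) that the HuggingFace GPT-2 wrapper returns a one-element tuple with no attention entry; extract_attn is a hypothetical helper, not a fairseq function:

def extract_attn(decoder_out):
    # The decoder may return just (logits,) or (logits, extra), where extra can
    # be a dict holding 'attn' or a list whose first element is the attention.
    attn = decoder_out[1] if len(decoder_out) > 1 else None
    if type(attn) is dict:
        attn = attn.get('attn', None)
    if type(attn) is list:
        attn = attn[0]
    return attn

print(extract_attn(("logits",)))                  # None: no attention entry
print(extract_attn(("logits", {"attn": [0.5]})))  # 0.5: unwrapped from dict and list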

@@ -54,7 +54,7 @@ class SequenceScorer(object):
         for model in models:
             model.eval()
             decoder_out = model(**net_input)
-            attn = decoder_out[1]
+            attn = decoder_out[1] if len(decoder_out) > 1 else None
             if type(attn) is dict:
                 attn = attn.get('attn', None)

@@ -312,8 +312,6 @@ def cli_main(modify_parser=None):
         port = random.randint(10000, 20000)
         args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
         args.distributed_rank = None  # set based on device id
-        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
-            logger.info('NOTE: you may get faster training with: --ddp-backend=no_c10d')
         torch.multiprocessing.spawn(
             fn=distributed_main,
             args=(args, ),