Updates to model API (#561)

Summary:
- `FairseqModel` -> `FairseqEncoderDecoderModel`
- add `FairseqDecoder.extract_features` and `FairseqDecoder.output_layer`
- `encoder_out_dict` -> `encoder_out`
- rm unused `remove_head` functions
- update docs
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/561

Differential Revision: D15271142

Pulled By: myleott

fbshipit-source-id: 8e8864e399336020f0271c780598e968ff51a264
Myle Ott authored on 2019-05-15 07:09:48 -07:00; committed by Facebook Github Bot
parent a0c5f9b860
commit dffb167449
16 changed files with 207 additions and 110 deletions
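
For downstream code the rename is mechanical: swap the base-class name (the old `FairseqModel` is kept as a deprecated alias, see the shim in the diff below). A minimal sketch with a hypothetical `MyModel`, not taken from this diff::

    from fairseq.models import FairseqEncoderDecoderModel, register_model

    @register_model('my_model')  # hypothetical model name
    class MyModel(FairseqEncoderDecoderModel):  # previously: class MyModel(FairseqModel)
        """forward(), extract_features() and output_layer() are inherited and
        delegate to self.encoder / self.decoder."""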

View File

@ -74,12 +74,18 @@ Adding new models
.. autoclass:: fairseq.models.BaseFairseqModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqModel
.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqEncoderModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqLanguageModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqMultiModel
:members:
:undoc-members:
.. autoclass:: fairseq.models.FairseqEncoder
:members:
.. autoclass:: fairseq.models.CompositeEncoder

View File

@ -2,7 +2,7 @@ Modules
=======
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.
be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
.. automodule:: fairseq.modules
:members:

View File

@ -41,7 +41,7 @@ New plug-ins are *registered* through a set of ``@register`` function
decorators, for example::
@register_model('my_lstm')
class MyLSTM(FairseqModel):
class MyLSTM(FairseqEncoderDecoderModel):
(...)
Once registered, new plug-ins can be used with the existing :ref:`Command-line

View File

@ -2,9 +2,9 @@ Tutorial: Simple LSTM
=====================
In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqModel` that encodes a source sentence with an
LSTM and then passes the final hidden state to a second LSTM that decodes the
target sentence (without attention).
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).
This tutorial covers:
@ -233,18 +233,18 @@ Once the model is registered we'll be able to use it with the existing
All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqModel` interface.
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.
Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::
from fairseq.models import FairseqModel, register_model
from fairseq.models import FairseqEncoderDecoderModel, register_model
# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.
@register_model('simple_lstm')
class SimpleLSTMModel(FairseqModel):
class SimpleLSTMModel(FairseqEncoderDecoderModel):
@staticmethod
def add_args(parser):
@ -308,7 +308,7 @@ the name ``'simple_lstm'``::
# We could override the ``forward()`` if we wanted more control over how
# the encoder and decoder interact, but it's not necessary for this
# tutorial since we can inherit the default implementation provided by
# the FairseqModel base class, which looks like:
# the FairseqEncoderDecoderModel base class, which looks like:
#
# def forward(self, src_tokens, src_lengths, prev_output_tokens):
# encoder_out = self.encoder(src_tokens, src_lengths)

View File

@ -14,10 +14,11 @@ from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import (
BaseFairseqModel,
FairseqEncoderModel,
FairseqEncoderDecoderModel,
FairseqLanguageModel,
FairseqModel,
FairseqMultiModel,
FairseqLanguageModel,
FairseqEncoderModel,
)
from .composite_encoder import CompositeEncoder
@ -30,6 +31,7 @@ __all__ = [
'DistributedFairseqModel',
'FairseqDecoder',
'FairseqEncoder',
'FairseqEncoderDecoderModel',
'FairseqEncoderModel',
'FairseqIncrementalDecoder',
'FairseqLanguageModel',
@ -56,12 +58,13 @@ def register_model(name):
For example::
@register_model('lstm')
class LSTM(FairseqModel):
class LSTM(FairseqEncoderDecoderModel):
(...)
.. note:: All models must implement the :class:`BaseFairseqModel` interface.
Typically you will extend :class:`FairseqModel` for sequence-to-sequence
tasks or :class:`FairseqLanguageModel` for language modeling tasks.
Typically you will extend :class:`FairseqEncoderDecoderModel` for
sequence-to-sequence tasks or :class:`FairseqLanguageModel` for
language modeling tasks.
Args:
name (str): the name of the model

View File

@ -18,25 +18,40 @@ class FairseqDecoder(nn.Module):
self.dictionary = dictionary
self.onnx_trace = False
def forward(self, prev_output_tokens, encoder_out):
def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
prev_output_tokens (LongTensor): shifted output tokens of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder_out (dict, optional): output from the encoder, used for
encoder-side attention
Returns:
tuple:
- the last decoder layer's output of shape
`(batch, tgt_len, vocab)`
- the last decoder layer's attention weights of shape
`(batch, tgt_len, src_len)`
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
x = self.output_layer(x)
return x, extra
def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
"""
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
raise NotImplementedError
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def output_layer(self, features, **kwargs):
"""
Project features to the default output size, e.g., vocabulary size.
Args:
features (Tensor): features returned by *extract_features*.
"""
raise NotImplementedError
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
@ -63,3 +78,6 @@ class FairseqDecoder(nn.Module):
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict
def prepare_for_onnx_export_(self):
self.onnx_trace = True
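
To make the new `extract_features`/`output_layer` contract concrete, here is a minimal, hypothetical subclass; the base-class `forward()` above simply composes the two hooks. The LSTM and linear layers are illustrative only and ignore `encoder_out` to keep the sketch short::

    import torch.nn as nn

    from fairseq.models import FairseqDecoder

    class ToyDecoder(FairseqDecoder):
        def __init__(self, dictionary, embed_dim=512):
            super().__init__(dictionary)
            self.embed_tokens = nn.Embedding(len(dictionary), embed_dim)
            self.lstm = nn.LSTM(embed_dim, embed_dim, batch_first=True)
            self.fc_out = nn.Linear(embed_dim, len(dictionary))

        def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
            # (batch, tgt_len) -> (batch, tgt_len, embed_dim);
            # encoder_out is ignored here only for brevity
            x, _ = self.lstm(self.embed_tokens(prev_output_tokens))
            return x, {'attn': None}

        def output_layer(self, features, **kwargs):
            # project features to vocabulary logits: (batch, tgt_len, vocab)
            return self.fc_out(features)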

View File

@ -15,7 +15,7 @@ class FairseqEncoder(nn.Module):
super().__init__()
self.dictionary = dictionary
def forward(self, src_tokens, src_lengths):
def forward(self, src_tokens, src_lengths=None, **kwargs):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape

View File

@ -12,8 +12,8 @@ class FairseqIncrementalDecoder(FairseqDecoder):
"""Base class for incremental decoders.
Incremental decoding is a special mode at inference time where the Model
only receives a single timestep of input corresponding to the immediately
previous output token (for input feeding) and must produce the next output
only receives a single timestep of input corresponding to the previous
output token (for input feeding) and must produce the next output
*incrementally*. Thus the model must cache any long-term state that is
needed about the sequence, e.g., hidden states, convolutional states, etc.
@ -33,22 +33,29 @@ class FairseqIncrementalDecoder(FairseqDecoder):
def __init__(self, dictionary):
super().__init__(dictionary)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
prev_output_tokens (LongTensor): shifted output tokens of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder_out (dict, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
incremental_state (dict, optional): dictionary used for storing
state during :ref:`Incremental decoding`
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
raise NotImplementedError
def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
"""
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
raise NotImplementedError
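
As a usage sketch of the incremental path (all names here -- `model`, `src_tokens`, `src_lengths`, `bos`, `max_len` -- are assumed, not taken from this diff): the caller keeps a single `incremental_state` dict alive across timesteps, and the decoder caches whatever per-step state it needs inside it. Greedy decoding over logits might look like::

    import torch

    incremental_state = {}  # reused across steps; decoders cache hidden/conv state in here
    encoder_out = model.encoder(src_tokens, src_lengths=src_lengths)
    prev_tokens = src_tokens.new_full((src_tokens.size(0), 1), bos)  # assumed start symbol id

    for _ in range(max_len):
        out, _ = model.decoder(
            prev_tokens, encoder_out=encoder_out, incremental_state=incremental_state,
        )
        next_token = out[:, -1, :].argmax(dim=-1, keepdim=True)
        prev_tokens = torch.cat([prev_tokens, next_token], dim=1)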

View File

@ -4,6 +4,9 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Base classes for various fairseq models.
"""
from typing import Dict, List, Optional
@ -11,6 +14,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.models import FairseqDecoder, FairseqEncoder
@ -30,7 +34,7 @@ class BaseFairseqModel(nn.Module):
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
raise NotImplementedError('FairseqModels must implement the build_model method')
raise NotImplementedError('Model must implement the build_model method')
def get_targets(self, sample, net_output):
"""Get targets from either the sample or the net's output."""
@ -48,14 +52,14 @@ class BaseFairseqModel(nn.Module):
return F.softmax(logits, dim=-1)
raise NotImplementedError
def extract_features(self, *args, **kwargs):
"""Similar to *forward* but only return features."""
return self(*args, **kwargs)
def max_positions(self):
"""Maximum length supported by the model."""
return None
def max_decoder_positions(self):
"""Maximum length supported by the decoder."""
return self.decoder.max_positions()
def load_state_dict(self, state_dict, strict=True):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
@ -139,7 +143,7 @@ class BaseFairseqModel(nn.Module):
self.apply(apply_prepare_for_onnx_export_)
class FairseqModel(BaseFairseqModel):
class FairseqEncoderDecoderModel(BaseFairseqModel):
"""Base class for encoder-decoder models.
Args:
@ -155,7 +159,7 @@ class FairseqModel(BaseFairseqModel):
assert isinstance(self.encoder, FairseqEncoder)
assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
Run the forward pass for an encoder-decoder model.
@ -174,19 +178,54 @@ class FairseqModel(BaseFairseqModel):
`(batch, tgt_len)`, for input feeding/teacher forcing
Returns:
the decoder's output, typically of shape `(batch, tgt_len, vocab)`
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
encoder_out = self.encoder(src_tokens, src_lengths)
decoder_out = self.decoder(prev_output_tokens, encoder_out)
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out
def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
"""
Similar to *forward* but only return features.
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
features = self.decoder.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return features
def output_layer(self, features, **kwargs):
"""Project features to the default output size (typically vocabulary size)."""
return self.decoder.output_layer(features, **kwargs)
def max_positions(self):
"""Maximum length supported by the model."""
return (self.encoder.max_positions(), self.decoder.max_positions())
def max_decoder_positions(self):
"""Maximum length supported by the decoder."""
return self.decoder.max_positions()
class FairseqModel(FairseqEncoderDecoderModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
utils.deprecation_warning(
'FairseqModel is deprecated, please use FairseqEncoderDecoderModel '
'or BaseFairseqModel instead',
stacklevel=4,
)
class FairseqMultiModel(BaseFairseqModel):
"""Base class for combining multiple encoder-decoder models."""
def __init__(self, encoders, decoders):
super().__init__()
assert encoders.keys() == decoders.keys()
@ -232,11 +271,13 @@ class FairseqMultiModel(BaseFairseqModel):
shared_dict, embed_dim, pretrained_embed_path
)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
decoder_outs = {}
for key in self.keys:
encoder_out = self.models[key].encoder(src_tokens, src_lengths)
decoder_outs[key] = self.models[key].decoder(prev_output_tokens, encoder_out)
encoder_out = self.models[key].encoder(src_tokens, src_lengths, **kwargs)
decoder_outs[key] = self.models[key].decoder(
prev_output_tokens, encoder_out, **kwargs,
)
return decoder_outs
def max_positions(self):
@ -271,7 +312,7 @@ class FairseqLanguageModel(BaseFairseqModel):
self.decoder = decoder
assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens, src_lengths):
def forward(self, src_tokens, **kwargs):
"""
Run the forward pass for a decoder-only model.
@ -283,22 +324,39 @@ class FairseqLanguageModel(BaseFairseqModel):
src_lengths (LongTensor): source sentence lengths of shape `(batch)`
Returns:
the decoder's output, typically of shape `(batch, seq_len, vocab)`
tuple:
- the decoder's output of shape `(batch, seq_len, vocab)`
- a dictionary with any model-specific outputs
"""
return self.decoder(src_tokens)
return self.decoder(src_tokens, **kwargs)
def extract_features(self, src_tokens, **kwargs):
"""
Similar to *forward* but only return features.
Returns:
tuple:
- the decoder's features of shape `(batch, seq_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
return self.decoder.extract_features(src_tokens, **kwargs)
def output_layer(self, features, **kwargs):
"""Project features to the default output size (typically vocabulary size)."""
return self.decoder.output_layer(features, **kwargs)
def max_positions(self):
"""Maximum length supported by the model."""
return self.decoder.max_positions()
def max_decoder_positions(self):
"""Maximum length supported by the decoder."""
return self.decoder.max_positions()
@property
def supported_targets(self):
return {'future'}
def remove_head(self):
"""Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
raise NotImplementedError()
class FairseqEncoderModel(BaseFairseqModel):
"""Base class for encoder-only models.
@ -316,14 +374,14 @@ class FairseqEncoderModel(BaseFairseqModel):
"""
Run the forward pass for an encoder-only model.
Feeds a batch of tokens through the encoder to generate logits.
Feeds a batch of tokens through the encoder to generate features.
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
src_lengths (LongTensor): source sentence lengths of shape `(batch)`
Returns:
the encoder's output, typically of shape `(batch, seq_len, vocab)`
the encoder's output, typically of shape `(batch, src_len, features)`
"""
return self.encoder(src_tokens, src_lengths, **kwargs)
@ -341,11 +399,3 @@ class FairseqEncoderModel(BaseFairseqModel):
def max_positions(self):
"""Maximum length supported by the model."""
return self.encoder.max_positions()
@property
def supported_targets(self):
return {'future'}
def remove_head(self):
"""Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
raise NotImplementedError()
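
In short, for both the encoder-decoder and language-model base classes, `forward()` is now exactly `output_layer()` applied to the features returned by `extract_features()`. A usage sketch, assuming a built `model` and the usual input tensors::

    # one call
    logits, extra = model(src_tokens, src_lengths, prev_output_tokens)

    # ...or the same thing in two steps, e.g. to reuse the features elsewhere
    features, extra = model.extract_features(src_tokens, src_lengths, prev_output_tokens)
    logits_again = model.output_layer(features)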

View File

@ -14,7 +14,7 @@ from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
@ -25,7 +25,7 @@ from fairseq.modules import (
@register_model('fconv')
class FConvModel(FairseqModel):
class FConvModel(FairseqEncoderDecoderModel):
"""
A fully convolutional model, i.e. a convolutional encoder and a
convolutional decoder, as described in `"Convolutional Sequence to Sequence
@ -406,10 +406,10 @@ class FConvDecoder(FairseqIncrementalDecoder):
else:
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None):
if encoder_out_dict is not None:
encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
if encoder_out is not None:
encoder_padding_mask = encoder_out['encoder_padding_mask']
encoder_out = encoder_out['encoder_out']
# split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)

View File

@ -16,7 +16,7 @@ from fairseq.models import (
CompositeEncoder,
FairseqDecoder,
FairseqEncoder,
FairseqModel,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
@ -30,7 +30,7 @@ from fairseq.modules import (
@register_model('fconv_self_att')
class FConvModelSelfAtt(FairseqModel):
class FConvModelSelfAtt(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder, pretrained_encoder=None):
super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
@ -371,9 +371,9 @@ class FConvDecoder(FairseqDecoder):
self.pretrained_decoder.fc2.register_forward_hook(save_output())
def forward(self, prev_output_tokens, encoder_out_dict):
encoder_out = encoder_out_dict['encoder']['encoder_out']
trained_encoder_out = encoder_out_dict['pretrained'] if self.pretrained else None
def forward(self, prev_output_tokens, encoder_out):
trained_encoder_out = encoder_out['pretrained'] if self.pretrained else None
encoder_out = encoder_out['encoder']['encoder_out']
encoder_a, encoder_b = self._split_encoder_out(encoder_out)

View File

@ -15,7 +15,7 @@ from fairseq import options, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
@ -31,7 +31,7 @@ from fairseq.modules import (
@register_model('lightconv')
class LightConvModel(FairseqModel):
class LightConvModel(FairseqEncoderDecoderModel):
"""
LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019)
<https://openreview.net/pdf?id=SkVhlh09tX>`_.
@ -213,13 +213,11 @@ class LightConvEncoder(FairseqEncoder):
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
def forward(self, src_tokens, src_lengths):
def forward(self, src_tokens, **unused):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
Returns:
dict:

View File

@ -13,7 +13,7 @@ from fairseq import options, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
@ -21,7 +21,7 @@ from fairseq.modules import AdaptiveSoftmax
@register_model('lstm')
class LSTMModel(FairseqModel):
class LSTMModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@ -356,9 +356,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
elif not self.share_input_output_embed:
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
encoder_padding_mask = encoder_out['encoder_padding_mask']
encoder_out = encoder_out['encoder_out']
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]

View File

@ -15,7 +15,7 @@ from fairseq import options, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
@ -29,7 +29,7 @@ from fairseq.modules import (
@register_model('transformer')
class TransformerModel(FairseqModel):
class TransformerModel(FairseqEncoderDecoderModel):
"""
Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
<https://arxiv.org/abs/1706.03762>`_.
@ -298,7 +298,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
output_embed_dim = args.decoder_output_dim
self.output_embed_dim = args.decoder_output_dim
padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
@ -321,13 +321,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.adaptive_softmax = None
self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \
if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None
self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
output_embed_dim,
self.output_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
@ -335,14 +335,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)
self.register_buffer('version', torch.Tensor([2]))
self.normalize = args.decoder_normalize_before and final_norm
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
@ -354,10 +354,21 @@ class TransformerDecoder(FairseqIncrementalDecoder):
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state)
x = self.output_layer(x)
return x, extra
def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
"""
Similar to *forward* but only return features.
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
# embed positions
positions = self.embed_positions(
@ -406,14 +417,18 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {'attn': attn, 'inner_states': inner_states}
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
return F.linear(features, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
return x, {'attn': attn, 'inner_states': inner_states}
return F.linear(features, self.embed_out)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""

View File

@ -83,10 +83,10 @@ class LanguageModelingTask(FairseqTask):
help='prepend beginning of sentence token (<s>)')
# fmt: on
def __init__(self, args, dictionary, output_dictionary, targets=None):
def __init__(self, args, dictionary, output_dictionary=None, targets=None):
super().__init__(args)
self.dictionary = dictionary
self.output_dictionary = output_dictionary
self.output_dictionary = output_dictionary or dictionary
if targets is None:
targets = ['future']

View File

@ -13,8 +13,8 @@ from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
FairseqModel,
)
from fairseq.tasks import FairseqTask
@ -154,7 +154,7 @@ class TestTranslationTask(FairseqTask):
return self.tgt_dict
class TestModel(FairseqModel):
class TestModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@ -170,7 +170,7 @@ class TestEncoder(FairseqEncoder):
super().__init__(dictionary)
self.args = args
def forward(self, src_tokens, src_lengths):
def forward(self, src_tokens, src_lengths=None, **kwargs):
return src_tokens
def reorder_encoder_out(self, encoder_out, new_order):
@ -184,7 +184,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
self.args = args
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
bbsz = prev_output_tokens.size(0)