FairseqEncoderModel

Summary: Base class for encoder-only models. Some models don't have a decoder part.

Reviewed By: myleott

Differential Revision: D14413406

fbshipit-source-id: f36473b91dcf3c835fd6d50e2eb6002afa75f11a
Dmytro Okhonko 2019-03-12 15:08:47 -07:00 committed by Facebook Github Bot
parent 7fc9a3be80
commit 9e1c880fbe
2 changed files with 41 additions and 0 deletions

fairseq/models/__init__.py

@@ -17,6 +17,7 @@ from .fairseq_model import (
     FairseqModel,  # noqa: F401
     FairseqMultiModel,  # noqa: F401
     FairseqLanguageModel,  # noqa: F401
+    FairseqEncoderModel,  # noqa: F401
 )
 from .composite_encoder import CompositeEncoder  # noqa: F401
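
With this re-export in place, downstream code can import the new base class directly from the package namespace (a usage sketch, not part of the diff):

    from fairseq.models import FairseqEncoderModel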

fairseq/models/fairseq_model.py

@@ -297,3 +297,43 @@ class FairseqLanguageModel(BaseFairseqModel):
     def remove_head(self):
         """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed."""
         raise NotImplementedError()
+
+
+class FairseqEncoderModel(BaseFairseqModel):
+    """Base class for encoder-only models.
+
+    Args:
+        encoder (FairseqEncoder): the encoder
+    """
+
+    def __init__(self, encoder):
+        super().__init__()
+        self.encoder = encoder
+        assert isinstance(self.encoder, FairseqEncoder)
+
+    def forward(self, src_tokens, src_lengths, **kwargs):
+        """
+        Run the forward pass for an encoder-only model.
+
+        Feeds a batch of tokens through the encoder to generate logits.
+
+        Args:
+            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
+            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
+
+        Returns:
+            the encoder's output, typically of shape `(batch, seq_len, vocab)`
+        """
+        return self.encoder(src_tokens, src_lengths, **kwargs)
+
+    def max_positions(self):
+        """Maximum length supported by the model."""
+        return self.encoder.max_positions()
+
+    @property
+    def supported_targets(self):
+        return {'future'}
+
+    def remove_head(self):
+        """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed."""
+        raise NotImplementedError()
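
For context, here is a minimal sketch of how a downstream model might build on this base class. It is not part of this commit: `SimpleEncoder`, the `simple_tagger` architecture name, and the embedding/projection sizes are illustrative assumptions; only `FairseqEncoder`, `FairseqEncoderModel`, `register_model`, and `register_model_architecture` come from fairseq itself.

    # Hypothetical example (not in this commit): a bag-of-words tagger
    # built on the new FairseqEncoderModel base class.
    import torch.nn as nn

    from fairseq.models import (
        FairseqEncoder,
        FairseqEncoderModel,
        register_model,
        register_model_architecture,
    )


    class SimpleEncoder(FairseqEncoder):
        """Illustrative encoder: embeds tokens and projects to per-token logits."""

        def __init__(self, dictionary, embed_dim, num_classes):
            super().__init__(dictionary)
            self.embed = nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())
            self.proj = nn.Linear(embed_dim, num_classes)

        def forward(self, src_tokens, src_lengths):
            # (batch, src_len) -> (batch, src_len, num_classes)
            return self.proj(self.embed(src_tokens))

        def max_positions(self):
            return 1024  # arbitrary limit for this sketch


    @register_model('simple_tagger')
    class SimpleTaggerModel(FairseqEncoderModel):

        @classmethod
        def build_model(cls, args, task):
            # Assumes the task exposes source/target dictionaries
            # (may be None for some tasks).
            encoder = SimpleEncoder(
                task.source_dictionary,
                args.encoder_embed_dim,
                len(task.target_dictionary),
            )
            return cls(encoder)


    @register_model_architecture('simple_tagger', 'simple_tagger')
    def base_architecture(args):
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 128)

Once registered, such a model is selectable with `--arch simple_tagger` like any other fairseq model; the base class supplies `forward`, `max_positions`, and the `FairseqEncoder` type check for free.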