diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py
index 2507e661e..812ef5cd5 100644
--- a/fairseq/models/__init__.py
+++ b/fairseq/models/__init__.py
@@ -17,6 +17,7 @@ from .fairseq_model import (
     FairseqModel,  # noqa: F401
     FairseqMultiModel,  # noqa: F401
     FairseqLanguageModel,  # noqa: F401
+    FairseqEncoderModel,  # noqa: F401
 )
 
 from .composite_encoder import CompositeEncoder  # noqa: F401
diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py
index e17c887ab..190eedabc 100644
--- a/fairseq/models/fairseq_model.py
+++ b/fairseq/models/fairseq_model.py
@@ -297,3 +297,43 @@ class FairseqLanguageModel(BaseFairseqModel):
     def remove_head(self):
         """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
         raise NotImplementedError()
+
+
+class FairseqEncoderModel(BaseFairseqModel):
+    """Base class for encoder-only models.
+
+    Args:
+        encoder (FairseqEncoder): the encoder
+    """
+
+    def __init__(self, encoder):
+        super().__init__()
+        self.encoder = encoder
+        assert isinstance(self.encoder, FairseqEncoder)
+
+    def forward(self, src_tokens, src_lengths, **kwargs):
+        """
+        Run the forward pass for an encoder-only model.
+
+        Feeds a batch of tokens through the encoder to generate logits.
+
+        Args:
+            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
+            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
+
+        Returns:
+            the encoder's output, typically of shape `(batch, seq_len, vocab)`
+        """
+        return self.encoder(src_tokens, src_lengths)
+
+    def max_positions(self):
+        """Maximum length supported by the model."""
+        return self.encoder.max_positions()
+
+    @property
+    def supported_targets(self):
+        return {'future'}
+
+    def remove_head(self):
+        """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
+        raise NotImplementedError()
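
For context (not part of the diff): a minimal sketch of how a model might subclass the new `FairseqEncoderModel` base class. The `BagOfTokensEncoder` and `ToyEncoderOnlyModel` names, the `toy_encoder_only` model name, and the fixed 1024-token cap are all hypothetical; the sketch assumes the usual `build_model(cls, args, task)` hook and a task that exposes `source_dictionary`.

```python
import torch.nn as nn

from fairseq.models import FairseqEncoder, FairseqEncoderModel, register_model


class BagOfTokensEncoder(FairseqEncoder):
    """Toy encoder: embeds each token and projects it to vocabulary logits."""

    def __init__(self, dictionary, embed_dim=128):
        super().__init__(dictionary)
        self.embed = nn.Embedding(
            len(dictionary), embed_dim, padding_idx=dictionary.pad(),
        )
        self.out_proj = nn.Linear(embed_dim, len(dictionary))

    def forward(self, src_tokens, src_lengths):
        x = self.embed(src_tokens)   # (batch, src_len, embed_dim)
        return self.out_proj(x)      # (batch, src_len, vocab)

    def max_positions(self):
        return 1024  # arbitrary cap for this toy example


@register_model('toy_encoder_only')
class ToyEncoderOnlyModel(FairseqEncoderModel):
    """Encoder-only model that simply wraps the toy encoder above."""

    @classmethod
    def build_model(cls, args, task):
        return cls(BagOfTokensEncoder(task.source_dictionary))
```

A real model registered this way would typically also implement `add_args` and register at least one named architecture via `register_model_architecture` so it can be selected from the command line.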