Register weights as a non-persistent buffer of SinusoidalPositionalEmbedding (#5213)

Yun Wang (Maigo) 2023-06-23 10:31:52 -07:00 committed by GitHub
parent a29952ce6d
commit 31fba013a0
6 changed files with 15 additions and 49 deletions
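
For context on the core change: `register_buffer(..., persistent=False)` keeps a tensor attached to the module, so it still follows `.to()` / `.cuda()` / `.half()`, but leaves it out of `state_dict()`. A minimal standalone sketch (toy module, not fairseq code) of the difference:

```python
import torch
import torch.nn as nn


class WithTable(nn.Module):
    def __init__(self):
        super().__init__()
        table = torch.arange(8, dtype=torch.float32)
        # Persistent buffer (the default): written to state_dict(), so it ends
        # up in every checkpoint and must be present when loading strictly.
        self.register_buffer("persistent_table", table.clone())
        # Non-persistent buffer: still moves with the module, but is excluded
        # from state_dict() and simply recomputed at construction time.
        self.register_buffer("derived_table", table.clone(), persistent=False)


m = WithTable()
print(sorted(m.state_dict().keys()))  # ['persistent_table']
m = m.to(torch.float64)
print(m.derived_table.dtype)          # torch.float64 -- followed the module
```

Since the sinusoidal table is deterministic and rebuilt in `__init__`, there is no reason to checkpoint it, which is what lets the `upgrade_state_dict_named` hacks in the files below disappear.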

View File

@@ -441,7 +441,6 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
                 # fmt: off
                 if isinstance(module, TransformerEncoderEmbedding):
                     new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight']
-                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor']
                 if isinstance(module, TransformerEncoderLayer):
                     for suffix in encoder_key_suffixes:
                         new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}']
@@ -456,7 +455,6 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
                     new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias']
                 if isinstance(module, TransformerDecoderEmbedding):
                     new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight']
-                    new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor']
                 if isinstance(module, TransformerDecoderOutputLayer):
                     new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight']
                 # fmt: on
@@ -741,14 +739,6 @@ class TransformerDecoder(FairseqDecoder):
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            weights_key = "{}.embed_positions.weights".format(name)
-            if weights_key in state_dict:
-                del state_dict[weights_key]
-            state_dict[
-                "{}.embed_positions._float_tensor".format(name)
-            ] = torch.FloatTensor(1)
-
         for i in range(len(self.layers)):
             # update layer norms
             layer_norm_map = {
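
The block removed here (and the near-identical blocks in the next four files) is the legacy compatibility shim. Because `weights` used to be a plain attribute that `.to()` / `.cuda()` could not see, the module carried a one-element `_float_tensor` buffer whose only job was to record the module's device and dtype, and every wrapping model rewrote checkpoint keys in `upgrade_state_dict_named`: it deleted a stale `embed_positions.weights` entry (written by very old fairseq versions) and force-inserted the `_float_tensor` placeholder. A simplified sketch of that old pattern (toy class; `torch.randn` stands in for the real sinusoidal table):

```python
import torch
import torch.nn as nn


class OldStyleSinusoidalEmbedding(nn.Module):
    """Sketch of the pre-commit pattern, reduced to the relevant parts."""

    def __init__(self, embedding_dim: int, init_size: int = 1024):
        super().__init__()
        # Plain attribute: invisible to .to()/.cuda() and to state_dict().
        self.weights = torch.randn(init_size, embedding_dim)
        # Dummy persistent buffer that only tracks the module's device/dtype;
        # this is the '_float_tensor' key that old checkpoints carry.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # The plain attribute has to be dragged to the right device/dtype by hand.
        self.weights = self.weights.to(self._float_tensor)
        return self.weights[positions]


emb = OldStyleSinusoidalEmbedding(embedding_dim=4)
print(sorted(emb.state_dict().keys()))  # ['_float_tensor'] -- the dummy key
```

Registering `weights` as a non-persistent buffer removes the need for both the dummy `_float_tensor` and the per-model checkpoint surgery.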

View File

@@ -294,12 +294,6 @@ class MaskedLMEncoder(FairseqEncoder):
         return self.max_positions

     def upgrade_state_dict_named(self, state_dict, name):
-        if isinstance(
-            self.sentence_encoder.embed_positions, SinusoidalPositionalEmbedding
-        ):
-            state_dict[
-                name + ".sentence_encoder.embed_positions._float_tensor"
-            ] = torch.FloatTensor(1)
         if not self.load_softmax:
             for k in list(state_dict.keys()):
                 if (

View File

@@ -399,14 +399,6 @@ class TransformerDecoderBase(FairseqIncrementalDecoder):
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            weights_key = "{}.embed_positions.weights".format(name)
-            if weights_key in state_dict:
-                del state_dict[weights_key]
-            state_dict[
-                "{}.embed_positions._float_tensor".format(name)
-            ] = torch.FloatTensor(1)
-
         if f"{name}.output_projection.weight" not in state_dict:
             if self.share_input_output_embed:
                 embed_out_key = f"{name}.embed_tokens.weight"

View File

@@ -305,14 +305,6 @@ class AugTransformerDecoderBase(TransformerDecoderBase):
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            weights_key = "{}.embed_positions.weights".format(name)
-            if weights_key in state_dict:
-                del state_dict[weights_key]
-            state_dict[
-                "{}.embed_positions._float_tensor".format(name)
-            ] = torch.FloatTensor(1)
-
         if f"{name}.output_projection.weight" not in state_dict:
             if self.share_input_output_embed:
                 embed_out_key = f"{name}.embed_tokens.weight"

View File

@@ -331,14 +331,6 @@ class TransformerEncoderBase(FairseqEncoder):
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
-        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            weights_key = "{}.embed_positions.weights".format(name)
-            if weights_key in state_dict:
-                print("deleting {0}".format(weights_key))
-                del state_dict[weights_key]
-            state_dict[
-                "{}.embed_positions._float_tensor".format(name)
-            ] = torch.FloatTensor(1)
         for i in range(self.num_layers):
             # update layer norms
             self.layers[i].upgrade_state_dict_named(
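
Backward compatibility now lives in the embedding itself instead of in every encoder/decoder: the `_load_from_state_dict` override added in the last file below deletes the deprecated `weights` and `_float_tensor` entries before the normal loading logic scans the dict. A hedged sketch (toy module, not the fairseq class) of an old-style checkpoint loading cleanly under `strict=True`:

```python
import torch
import torch.nn as nn


class NewStyleEmbedding(nn.Module):
    """Toy stand-in for the post-commit embedding; randn replaces the real table."""

    def __init__(self, embedding_dim: int = 16, init_size: int = 32):
        super().__init__()
        # Rebuilt at construction time, so it never needs to be checkpointed.
        self.register_buffer(
            "weights", torch.randn(init_size, embedding_dim), persistent=False
        )

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # Drop keys written by older code before the parent class scans the dict.
        for key in ("weights", "_float_tensor"):
            if prefix + key in state_dict:
                del state_dict[prefix + key]
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


old_checkpoint = {
    "weights": torch.randn(32, 16),         # table stored by very old versions
    "_float_tensor": torch.FloatTensor(1),  # dummy device/dtype tracker
}
result = NewStyleEmbedding().load_state_dict(old_checkpoint, strict=True)
print(result)  # <All keys matched successfully>
```

Deleting the stale keys before calling `super()._load_from_state_dict` is what keeps them from being flagged as missing or unexpected during a strict load.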

View File

@@ -9,7 +9,7 @@ from typing import Any, Optional
 import torch
 import torch.onnx.operators
 from fairseq import utils
-from torch import Tensor, nn
+from torch import nn, Tensor


 class SinusoidalPositionalEmbedding(nn.Module):
@@ -22,16 +22,23 @@ class SinusoidalPositionalEmbedding(nn.Module):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx if padding_idx is not None else 0
-        self.weights = SinusoidalPositionalEmbedding.get_embedding(
+        self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding(
             init_size, embedding_dim, padding_idx
-        )
-        self.onnx_trace = False
-        self.register_buffer("_float_tensor", torch.FloatTensor(1))
+        ), persistent=False)
         self.max_positions = int(1e5)
+        self.onnx_trace = False

     def prepare_for_onnx_export_(self):
         self.onnx_trace = True

+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # Ignore some deprecated keys that were used in older versions
+        deprecated_keys = ["weights", "_float_tensor"]
+        for key in deprecated_keys:
+            if prefix + key in state_dict:
+                del state_dict[prefix + key]
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
     @staticmethod
     def get_embedding(
         num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
@@ -68,12 +75,11 @@ class SinusoidalPositionalEmbedding(nn.Module):
         bspair = torch.onnx.operators.shape_as_tensor(input)
         bsz, seq_len = bspair[0], bspair[1]
         max_pos = self.padding_idx + 1 + seq_len
-        if self.weights is None or max_pos > self.weights.size(0):
-            # recompute/expand embeddings if needed
+        if max_pos > self.weights.size(0):
+            # expand embeddings if needed
             self.weights = SinusoidalPositionalEmbedding.get_embedding(
                 max_pos, self.embedding_dim, self.padding_idx
-            )
-        self.weights = self.weights.to(self._float_tensor)
+            ).to(self.weights)

         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
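
On the forward path, `.to(self._float_tensor)` becomes `.to(self.weights)`: when a longer sequence arrives, the table is regrown and inherits device and dtype from the existing buffer. A short usage sketch, assuming the post-commit module is importable as `fairseq.modules.SinusoidalPositionalEmbedding` (as in the fairseq tree this commit targets):

```python
import torch
from fairseq.modules import SinusoidalPositionalEmbedding

# Deliberately small init_size so the forward pass has to regrow the table.
emb = SinusoidalPositionalEmbedding(embedding_dim=16, padding_idx=1, init_size=8)
tokens = torch.full((2, 50), 5, dtype=torch.long)  # seq_len 50 > init_size, no padding

out = emb(tokens)
print(out.shape)             # torch.Size([2, 50, 16])
print(emb.weights.size(0))   # 52 -- regrown to padding_idx + 1 + seq_len
print(len(emb.state_dict())) # 0 -- nothing to checkpoint
```

Reassigning the regrown tensor to `self.weights` routes through `nn.Module.__setattr__` into the existing buffer slot, and the buffer keeps its non-persistent status, so even the expanded table stays out of checkpoints.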