diff --git a/examples/data2vec/models/mae.py b/examples/data2vec/models/mae.py index 5101e070e..a3b5f72a4 100644 --- a/examples/data2vec/models/mae.py +++ b/examples/data2vec/models/mae.py @@ -21,7 +21,11 @@ from fairseq.dataclass import FairseqDataclass from fairseq.models import BaseFairseqModel, register_model from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer -from apex.normalization import FusedLayerNorm +try: + from apex.normalization import FusedLayerNorm +except: + FusedLayerNorm = nn.LayerNorm + import torch.nn.functional as F