fix xlsr checkpoint finetuning saving issues (#2013)

Summary:
Fixes an issue with some old checkpoints that had deeply nested namespaces containing choices enums - most prominently the XLSR-53 checkpoint.

fixes #3634

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2013

Reviewed By: xuqiantong

Differential Revision: D29511325

Pulled By: alexeib

fbshipit-source-id: 79df978afa7482b4ce3aaf7396e193626181aa17
This commit is contained in:
alexeib 2021-07-01 08:36:02 -07:00 committed by Facebook GitHub Bot
parent 9bee82e4a7
commit 096f492a22
2 changed files with 42 additions and 30 deletions

View File

@ -17,7 +17,7 @@ from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import FairseqConfig
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf import DictConfig, OmegaConf, open_dict, _utils
logger = logging.getLogger(__name__)
@ -341,6 +341,17 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
return overrides, deletes
class omegaconf_no_object_check:
    """Context manager that temporarily disables omegaconf's primitive-type check.

    While the ``with`` block is active, ``omegaconf._utils.is_primitive_type``
    is replaced by a function that returns ``True`` for every argument. On
    exit the function that was in place when this object was constructed is
    put back, whether or not the block raised.

    The replacement is made on the ``_utils`` module attribute itself, so it
    is visible to any other code that looks up ``_utils.is_primitive_type``
    while the block is active.
    """

    def __init__(self):
        # Remember the current implementation so __exit__ can restore it.
        self.old_is_primitive = _utils.is_primitive_type

    def __enter__(self):
        # Treat every value as primitive for the duration of the block.
        _utils.is_primitive_type = lambda _: True

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the saved implementation. Returning None (falsy) lets any
        # exception raised inside the block propagate to the caller.
        _utils.is_primitive_type = self.old_is_primitive
def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
"""Convert a flat argparse.Namespace to a structured DictConfig."""
@ -370,41 +381,40 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
# omegaconf version that supports object flags, or when we migrate all existing models
from omegaconf import _utils
old_primitive = _utils.is_primitive_type
_utils.is_primitive_type = lambda _: True
with omegaconf_no_object_check():
if cfg.task is None and getattr(args, "task", None):
cfg.task = Namespace(**vars(args))
from fairseq.tasks import TASK_REGISTRY
if cfg.task is None and getattr(args, "task", None):
cfg.task = Namespace(**vars(args))
from fairseq.tasks import TASK_REGISTRY
_set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
cfg.task._name = args.task
if cfg.model is None and getattr(args, "arch", None):
cfg.model = Namespace(**vars(args))
from fairseq.models import ARCH_MODEL_REGISTRY
_set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
cfg.task._name = args.task
if cfg.model is None and getattr(args, "arch", None):
cfg.model = Namespace(**vars(args))
from fairseq.models import ARCH_MODEL_REGISTRY
_set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
cfg.model._name = args.arch
if cfg.optimizer is None and getattr(args, "optimizer", None):
cfg.optimizer = Namespace(**vars(args))
from fairseq.optim import OPTIMIZER_REGISTRY
_set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
cfg.model._name = args.arch
if cfg.optimizer is None and getattr(args, "optimizer", None):
cfg.optimizer = Namespace(**vars(args))
from fairseq.optim import OPTIMIZER_REGISTRY
_set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
cfg.optimizer._name = args.optimizer
if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
cfg.lr_scheduler = Namespace(**vars(args))
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
_set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
cfg.optimizer._name = args.optimizer
if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
cfg.lr_scheduler = Namespace(**vars(args))
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
_set_legacy_defaults(
cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]
)
cfg.lr_scheduler._name = args.lr_scheduler
if cfg.criterion is None and getattr(args, "criterion", None):
cfg.criterion = Namespace(**vars(args))
from fairseq.criterions import CRITERION_REGISTRY
_set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler])
cfg.lr_scheduler._name = args.lr_scheduler
if cfg.criterion is None and getattr(args, "criterion", None):
cfg.criterion = Namespace(**vars(args))
from fairseq.criterions import CRITERION_REGISTRY
_set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
cfg.criterion._name = args.criterion
_set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
cfg.criterion._name = args.criterion
_utils.is_primitive_type = old_primitive
OmegaConf.set_struct(cfg, True)
return cfg

View File

@ -337,6 +337,8 @@ class Wav2VecEncoder(FairseqEncoder):
w2v_args = state.get("cfg", None)
if w2v_args is None:
w2v_args = convert_namespace_to_omegaconf(state["args"])
w2v_args.criterion = None
w2v_args.lr_scheduler = None
cfg.w2v_args = w2v_args
else:
state = None