mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-04 04:37:58 +03:00
Make infer compatible with voicebox inference pipeline
This commit is contained in:
parent
272c4c5197
commit
fdca11448e
@ -45,6 +45,7 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config_path = Path(__file__).resolve().parent / "conf"
|
||||
config_path = str(config_path)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -323,16 +323,19 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
|
||||
|
||||
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
|
||||
# omegaconf version that supports object flags, or when we migrate all existing models
|
||||
from omegaconf import __version__ as oc_version
|
||||
from omegaconf import _utils
|
||||
|
||||
old_primitive = _utils.is_primitive_type
|
||||
_utils.is_primitive_type = lambda _: True
|
||||
if oc_version < "2.2":
|
||||
old_primitive = _utils.is_primitive_type
|
||||
_utils.is_primitive_type = lambda _: True
|
||||
|
||||
state["cfg"] = OmegaConf.create(state["cfg"])
|
||||
|
||||
_utils.is_primitive_type = old_primitive
|
||||
OmegaConf.set_struct(state["cfg"], True)
|
||||
state["cfg"] = OmegaConf.create(state["cfg"])
|
||||
|
||||
_utils.is_primitive_type = old_primitive
|
||||
OmegaConf.set_struct(state["cfg"], True)
|
||||
else:
|
||||
state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
|
||||
if arg_overrides is not None:
|
||||
overwrite_args_by_name(state["cfg"], arg_overrides)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user