Make infer compatible with voicebox inference pipeline

This commit is contained in:
lematt1991 2023-09-25 17:02:46 +00:00
parent 272c4c5197
commit fdca11448e
2 changed files with 10 additions and 6 deletions

View File

@ -45,6 +45,7 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
config_path = Path(__file__).resolve().parent / "conf"
config_path = str(config_path)
@dataclass

View File

@ -323,16 +323,19 @@ def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
# omegaconf version that supports object flags, or when we migrate all existing models
from omegaconf import __version__ as oc_version
from omegaconf import _utils
old_primitive = _utils.is_primitive_type
_utils.is_primitive_type = lambda _: True
if oc_version < "2.2":
old_primitive = _utils.is_primitive_type
_utils.is_primitive_type = lambda _: True
state["cfg"] = OmegaConf.create(state["cfg"])
_utils.is_primitive_type = old_primitive
OmegaConf.set_struct(state["cfg"], True)
state["cfg"] = OmegaConf.create(state["cfg"])
_utils.is_primitive_type = old_primitive
OmegaConf.set_struct(state["cfg"], True)
else:
state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})
if arg_overrides is not None:
overwrite_args_by_name(state["cfg"], arg_overrides)