diff --git a/hubconf.py b/hubconf.py index 1eb25f870..992c259fa 100644 --- a/hubconf.py +++ b/hubconf.py @@ -5,6 +5,8 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. +import functools + from fairseq.models import MODEL_REGISTRY @@ -18,5 +20,11 @@ dependencies = [ ] -for model, cls in MODEL_REGISTRY.items(): - globals()[model] = cls.from_pretrained +for model_type, _cls in MODEL_REGISTRY.items(): + for model_name in _cls.hub_models().keys(): + globals()[model_name] = functools.partial( + _cls.from_pretrained, + model_name_or_path=model_name, + ) + # to simplify the interface we only expose named models + #globals()[model_type] = _cls.from_pretrained