diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index a5be928f0..4f74fb025 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -27,6 +27,8 @@ from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder from omegaconf import DictConfig, open_dict, OmegaConf +from pathlib import Path + logger = logging.getLogger(__name__) @@ -483,6 +485,34 @@ def load_model_ensemble_and_task( return ensemble, cfg, task +def load_model_ensemble_and_task_from_hf_hub( + model_id, + cache_dir: Optional[str] = None, + arg_overrides: Optional[Dict[str, Any]] = None, + **kwargs: Any, +): + try: + from huggingface_hub import snapshot_download + except ImportError: + raise ImportError( + "You need to install huggingface_hub to use `load_from_hf_hub`. " + "See https://pypi.org/project/huggingface-hub/ for installation." + ) + + library_name = "fairseq" + cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix() + cache_dir = snapshot_download( + model_id, cache_dir=cache_dir, library_name=library_name, **kwargs + ) + + _arg_overrides = arg_overrides or {} + _arg_overrides["data"] = cache_dir + return load_model_ensemble_and_task( + [p.as_posix() for p in Path(cache_dir).glob("*.pt")], + arg_overrides=_arg_overrides + ) + + def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False): """Retrieves all checkpoints found in `path` directory. diff --git a/tests/test_hf_hub.py b/tests/test_hf_hub.py new file mode 100644 index 000000000..f39d1fa8e --- /dev/null +++ b/tests/test_hf_hub.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +try: + import huggingface_hub +except ImportError: + huggingface_hub = None + +from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub + + +@unittest.skipIf(not huggingface_hub, "Requires huggingface_hub install") +class TestHuggingFaceHub(unittest.TestCase): + @torch.no_grad() + def test_hf_fastspeech2(self): + hf_model_id = "facebook/fastspeech2-en-ljspeech" + models, cfg, task = load_model_ensemble_and_task_from_hf_hub(hf_model_id) + self.assertTrue(len(models) > 0) + + +if __name__ == "__main__": + unittest.main()