Add loading from HuggingFace Hub

Summary: Add loading from HuggingFace Hub. Revised from and to replace D32697723 (accepted).

Reviewed By: pipibjc, dianaml0

Differential Revision: D32964041

fbshipit-source-id: 39676aa0ecb10454ae76b70968d5abe96ab6da54
This commit is contained in:
Changhan Wang 2021-12-10 16:54:15 -08:00 committed by Facebook GitHub Bot
parent ce961a9fd2
commit 8548f1d401
2 changed files with 58 additions and 0 deletions

View File

@ -27,6 +27,8 @@ from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, open_dict, OmegaConf
from pathlib import Path
logger = logging.getLogger(__name__)
@ -483,6 +485,34 @@ def load_model_ensemble_and_task(
return ensemble, cfg, task
def load_model_ensemble_and_task_from_hf_hub(
model_id,
cache_dir: Optional[str] = None,
arg_overrides: Optional[Dict[str, Any]] = None,
**kwargs: Any,
):
try:
from huggingface_hub import snapshot_download
except ImportError:
raise ImportError(
"You need to install huggingface_hub to use `load_from_hf_hub`. "
"See https://pypi.org/project/huggingface-hub/ for installation."
)
library_name = "fairseq"
cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
cache_dir = snapshot_download(
model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
)
_arg_overrides = arg_overrides or {}
_arg_overrides["data"] = cache_dir
return load_model_ensemble_and_task(
[p.as_posix() for p in Path(cache_dir).glob("*.pt")],
arg_overrides=_arg_overrides
)
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
"""Retrieves all checkpoints found in `path` directory.

28
tests/test_hf_hub.py Normal file
View File

@ -0,0 +1,28 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
try:
import huggingface_hub
except ImportError:
huggingface_hub = None
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
@unittest.skipIf(not huggingface_hub, "Requires huggingface_hub install")
class TestHuggingFaceHub(unittest.TestCase):
@torch.no_grad()
def test_hf_fastspeech2(self):
hf_model_id = "facebook/fastspeech2-en-ljspeech"
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(hf_model_id)
self.assertTrue(len(models) > 0)
if __name__ == "__main__":
unittest.main()