fairseq/tests/test_hf_hub.py
Changhan Wang 8548f1d401 Add loading from HuggingFace Hub
Summary: Add support for loading models from the HuggingFace Hub. This revision is derived from, and replaces, D32697723 (accepted).

Reviewed By: pipibjc, dianaml0

Differential Revision: D32964041

fbshipit-source-id: 39676aa0ecb10454ae76b70968d5abe96ab6da54
2021-12-10 16:55:12 -08:00

29 lines
795 B
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch

try:
    import huggingface_hub
except ImportError:
    # huggingface_hub is optional; the test class below is skipped when it is absent.
    huggingface_hub = None

from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
@unittest.skipIf(not huggingface_hub, "Requires huggingface_hub install")
class TestHuggingFaceHub(unittest.TestCase):
    """Tests for loading fairseq models from the HuggingFace Hub.

    The whole class is skipped when the optional ``huggingface_hub`` package
    could not be imported. The tests resolve models by Hub ID, so they
    presumably need network access or a pre-populated local cache.
    NOTE(review): confirm the CI environment provides one of these.
    """

    @torch.no_grad()
    def test_hf_fastspeech2(self):
        """Load ``facebook/fastspeech2-en-ljspeech`` and check a model is returned."""
        hf_model_id = "facebook/fastspeech2-en-ljspeech"
        # Only the model list is checked; config and task are intentionally unused.
        models, _cfg, _task = load_model_ensemble_and_task_from_hf_hub(hf_model_id)
        # assertGreater reports the actual count on failure, whereas
        # assertTrue(len(models) > 0) would only report "False is not true".
        self.assertGreater(len(models), 0)
if __name__ == "__main__":
    # When executed directly as a script (not imported by a test runner),
    # discover and run the TestCase classes defined in this module.
    unittest.main()