fairseq/tests/speech_recognition/test_vggtransformer.py
Dmytro Okhonko 72f9364cc6 Asr initial push (#810)
Summary:
Initial code for speech recognition task.
Right now only one ASR model added - https://arxiv.org/abs/1904.11660

unit test testing:
python -m unittest discover tests

also run model training with this code and obtained
5.0 test_clean | 13.4 test_other
on librispeech with pytorch/audio features
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/810

Reviewed By: cpuhrsch

Differential Revision: D16706659

Pulled By: okhonko

fbshipit-source-id: 89a5f9883e50bc0e548234287aa0ea73f7402514
2019-08-08 02:46:12 -07:00

136 lines
4.5 KiB
Python

#!/usr/bin/env python3
# import models/encoder/decoder to be tested
from examples.speech_recognition.models.vggtransformer import (
TransformerDecoder,
VGGTransformerEncoder,
VGGTransformerModel,
vggtransformer_1,
vggtransformer_2,
vggtransformer_base,
)
# import base test class
from .asr_test_base import (
DEFAULT_TEST_VOCAB_SIZE,
TestFairseqDecoderBase,
TestFairseqEncoderBase,
TestFairseqEncoderDecoderModelBase,
get_dummy_dictionary,
get_dummy_encoder_output,
get_dummy_input,
)
class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
def setUp(self):
def override_config(args):
"""
vggtrasformer_1 use 14 layers of transformer,
for testing purpose, it is too expensive. For fast turn-around
test, reduce the number of layers to 3.
"""
args.transformer_enc_config = (
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
)
super().setUp()
extra_args_setter = [vggtransformer_1, override_config]
self.setUpModel(VGGTransformerModel, extra_args_setter)
self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
def setUp(self):
def override_config(args):
"""
vggtrasformer_2 use 16 layers of transformer,
for testing purpose, it is too expensive. For fast turn-around
test, reduce the number of layers to 3.
"""
args.transformer_enc_config = (
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
)
super().setUp()
extra_args_setter = [vggtransformer_2, override_config]
self.setUpModel(VGGTransformerModel, extra_args_setter)
self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
def setUp(self):
def override_config(args):
"""
vggtrasformer_base use 12 layers of transformer,
for testing purpose, it is too expensive. For fast turn-around
test, reduce the number of layers to 3.
"""
args.transformer_enc_config = (
"((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
)
super().setUp()
extra_args_setter = [vggtransformer_base, override_config]
self.setUpModel(VGGTransformerModel, extra_args_setter)
self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))
class VGGTransformerEncoderTest(TestFairseqEncoderBase):
def setUp(self):
super().setUp()
self.setUpInput(get_dummy_input(T=50, D=80, B=5))
def test_forward(self):
print("1. test standard vggtransformer")
self.setUpEncoder(VGGTransformerEncoder(input_feat_per_channel=80))
super().test_forward()
print("2. test vggtransformer with limited right context")
self.setUpEncoder(
VGGTransformerEncoder(
input_feat_per_channel=80, transformer_context=(-1, 5)
)
)
super().test_forward()
print("3. test vggtransformer with limited left context")
self.setUpEncoder(
VGGTransformerEncoder(
input_feat_per_channel=80, transformer_context=(5, -1)
)
)
super().test_forward()
print("4. test vggtransformer with limited right context and sampling")
self.setUpEncoder(
VGGTransformerEncoder(
input_feat_per_channel=80,
transformer_context=(-1, 12),
transformer_sampling=(2, 2),
)
)
super().test_forward()
print("5. test vggtransformer with windowed context and sampling")
self.setUpEncoder(
VGGTransformerEncoder(
input_feat_per_channel=80,
transformer_context=(12, 12),
transformer_sampling=(2, 2),
)
)
class TransformerDecoderTest(TestFairseqDecoderBase):
def setUp(self):
super().setUp()
dict = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
decoder = TransformerDecoder(dict)
dummy_encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))
self.setUpDecoder(decoder)
self.setUpInput(dummy_encoder_output)
self.setUpPrevOutputTokens()