Add max position params to speech recognition (#1783)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes https://github.com/pytorch/fairseq/issues/1782.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding �
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1783

Reviewed By: okhonko

Differential Revision: D21663633

Pulled By: myleott

fbshipit-source-id: 5f3b4b7df83e27d866efb489daeffb3b38a66f38
This commit is contained in:
Marco Gaido 2020-06-23 06:46:50 -07:00 committed by Facebook GitHub Bot
parent d0ccc3e02e
commit a12c5c5de8

View File

@ -6,6 +6,7 @@
import json
import os
import re
import sys
import torch
from fairseq.data import Dictionary
@ -77,6 +78,10 @@ class SpeechRecognitionTask(FairseqTask):
parser.add_argument(
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
)
parser.add_argument('--max-source-positions', default=sys.maxsize, type=int, metavar='N',
help='max number of frames in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
def __init__(self, args, tgt_dict):
super().__init__(args)
@ -132,3 +137,7 @@ class SpeechRecognitionTask(FairseqTask):
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
return None
def max_positions(self):
"""Return the max speech and sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)