Add vectorized implementation to wav2vec2 sample_negatives (#2683)

Summary:
- Replaces the Python for loop in `sample_negatives` with a vectorized
implementation, which speeds up fairseq-hydra training by ~8-10%.
- The for-loop penalty is not incurred with torch distributed.
- fairseq-hydra train starts many more processes/threads, which probably
slows the loop down as well.

# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Improves wav2vec2 pretraining speed with fairseq-hydra train.

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2683

Reviewed By: arbabu123

Differential Revision: D32740855

Pulled By: alexeib

fbshipit-source-id: 1003a819679521ae1ae011cd79517e1035107e35
Commit c620ed066f (parent 0dfd6b6240), authored by Apoorv Vyas on 2021-12-08 18:18:48 -08:00, committed by Facebook GitHub Bot.


```diff
@@ -464,8 +464,7 @@ class Wav2Vec2Model(BaseFairseqModel):
                 cross_neg_idxs[cross_neg_idxs >= tszs] += 1

             if self.n_negatives > 0:
-                for i in range(1, bsz):
-                    neg_idxs[i] += i * high
+                neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)
             else:
                 neg_idxs = cross_neg_idxs
```
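The change replaces a per-row Python loop with a single broadcast add of per-batch offsets. A minimal NumPy sketch of the same idea (the names `bsz`, `num`, and `high` follow the diff; the random draw is a stand-in for the real negative sampling, and `np.arange(bsz)[:, None]` plays the role of `torch.arange(bsz).unsqueeze(1)` in the PR):

```python
import numpy as np

bsz, num, high = 4, 6, 10  # batch size, negatives per item, candidates per utterance

rng = np.random.default_rng(0)
neg_idxs = rng.integers(0, high, size=(bsz, num))  # stand-in for sampled indices

# Old approach: offset each row into its batch's index range with a Python loop
looped = neg_idxs.copy()
for i in range(1, bsz):
    looped[i] += i * high

# New approach: one broadcast add, no Python-level loop
vectorized = neg_idxs + np.arange(bsz)[:, None] * high

assert np.array_equal(looped, vectorized)
```

Row `i` of the offset column is `i * high`, so row 0 is left unchanged in both versions, matching the loop's `range(1, bsz)` start.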