Mirror of https://github.com/facebookresearch/fairseq.git, synced 2024-11-20 18:41:02 +03:00
Add vectorized implementation to wav2vec2 sample_negatives (#2683)
Summary:
- Updates the for loop to a vectorized implementation, which speeds up fairseq-hydra training by ~8-10%.
- The for-loop penalty is not incurred with torch distributed; fairseq-hydra-train starts many more process threads, which probably slows the loop down as well.

# Before submitting
- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Improves wav2vec2 pretraining speed with fairseq-hydra-train.

## PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2683
Reviewed By: arbabu123
Differential Revision: D32740855
Pulled By: alexeib
fbshipit-source-id: 1003a819679521ae1ae011cd79517e1035107e35
This commit is contained in:
parent 0dfd6b6240
commit c620ed066f
```diff
@@ -464,8 +464,7 @@ class Wav2Vec2Model(BaseFairseqModel):
         cross_neg_idxs[cross_neg_idxs >= tszs] += 1

         if self.n_negatives > 0:
-            for i in range(1, bsz):
-                neg_idxs[i] += i * high
+            neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)
         else:
             neg_idxs = cross_neg_idxs
```
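The change replaces a per-row Python loop with a single broadcasted add. A minimal standalone sketch of the equivalence (the `bsz`, `high`, and `n_negatives` values here are illustrative, not fairseq's actual configuration):

```python
import torch

# Illustrative sizes: batch size, candidates per batch element, negatives per timestep
bsz, tsz, n_negatives = 4, 6, 10
high = tsz  # upper bound for sampled indices within one batch element

# Per-batch negative indices in [0, high), as produced before the offset step
neg_idxs = torch.randint(low=0, high=high, size=(bsz, n_negatives * tsz))

# Original loop: shift row i by i * high so indices address a flattened batch
looped = neg_idxs.clone()
for i in range(1, bsz):
    looped[i] += i * high

# Vectorized replacement: broadcast a (bsz, 1) column of per-row offsets
vectorized = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)

assert torch.equal(looped, vectorized)
```

Broadcasting `torch.arange(bsz).unsqueeze(1) * high` against the `(bsz, n_negatives * tsz)` index tensor applies the same row-wise offset in one kernel launch instead of `bsz - 1` Python-level iterations, which is where the reported ~8-10% training speedup comes from.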