Update align_and_segment.py (#5317)

Fix MMS alignment code
This commit is contained in:
Vineel Pratap 2023-09-07 11:25:28 -07:00 committed by GitHub
parent 4db264940f
commit b5d89cddc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -87,13 +87,14 @@ def get_alignments(
blank = dictionary["<blank>"]
targets = torch.tensor(token_indices, dtype=torch.int32).to(DEVICE)
input_lengths = torch.tensor(emissions.shape[0])
target_lengths = torch.tensor(targets.shape[0])
input_lengths = torch.tensor(emissions.shape[0]).unsqueeze(-1)
target_lengths = torch.tensor(targets.shape[0]).unsqueeze(-1)
path, _ = F.forced_align(
emissions, targets, input_lengths, target_lengths, blank=blank
emissions.unsqueeze(0), targets.unsqueeze(0), input_lengths, target_lengths, blank=blank
)
path = path.to("cpu").tolist()
path = path.squeeze().to("cpu").tolist()
segments = merge_repeats(path, {v: k for k, v in dictionary.items()})
return segments, stride