Fix a bug of the ['corpus_key'] of multi_corpus_dataset

The ['corpus_key'] in batch['net_input']['corpus_key'] provides false information. Fix this bug in multi_corpus_dataset.py file.
This commit is contained in:
Ziyang Ma 2023-10-30 11:06:37 +08:00 committed by GitHub
parent da8fb63088
commit 0c1eac03df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -215,6 +215,12 @@ class MultiCorpusDataset(FairseqDataset):
except Exception:
print(f"Collating failed for key {key}", flush=True)
raise
# map key to dataset index
corpus_key_list = []
for sample in samples:
_, key = self._map_index(sample["full_id"])
corpus_key_list.append(list(self.datasets.keys()).index(key))
batch["net_input"]["corpus_key"] = torch.Tensor(dataset_idx_list).long()
return batch
else:
# Subclasses may override __getitem__ to not specify full_id