Add transcript option for asr-bleu (#4981)

This commit is contained in:
Xutai Ma 2023-02-08 23:16:22 -05:00 committed by GitHub
parent 214c0cbd6f
commit ad0e69cd99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,3 +1,4 @@
import os
from typing import Dict, List
import sacrebleu
import pandas as pd
@ -37,10 +38,13 @@ def remove_tone(text):
def extract_audio_for_eval(audio_dirpath: str, audio_format: str):
if audio_format == "n_pred.wav":
"""
The assumption here is that pred_0.wav corresponds to the reference at line position 0 from the reference manifest
The assumption here is that 0_pred.wav corresponds to the reference at line position 0 from the reference manifest
"""
audio_list = []
audio_fp_list = glob((Path(audio_dirpath) / "*_pred.wav").as_posix())
audio_fp_list = sorted(
audio_fp_list, key=lambda x: int(os.path.basename(x).split("_")[0])
)
for i in range(len(audio_fp_list)):
try:
audio_fp = (Path(audio_dirpath) / f"{i}_pred.wav").as_posix()
@ -154,7 +158,7 @@ def run_asr_bleu(args):
print(bleu_score)
return bleu_score
return prediction_transcripts, bleu_score
def main():
@ -206,10 +210,16 @@ def main():
type=str,
help="If specified, the resulting BLEU score will be written to this file path as txt file",
)
parser.add_argument(
"--transcripts_path",
default=None,
type=str,
help="If specified, the predicted transcripts will be written to this path as a txt file.",
)
args = parser.parse_args()
bleu_score = run_asr_bleu(args)
prediction_transcripts, bleu_score = run_asr_bleu(args)
result_filename = f"{args.reference_format}_{args.lang}_bleu.txt"
if args.results_dirpath is not None:
if not Path(args.results_dirpath).exists():
@ -217,6 +227,11 @@ def main():
with open(Path(args.results_dirpath) / result_filename, "w") as f:
f.write(bleu_score.format(width=2))
if args.transcripts_path is not None:
with open(args.transcripts_path, "w") as f:
for transcript in prediction_transcripts:
f.write(transcript + "\n")
if __name__ == "__main__":
main()