mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-11-10 15:08:53 +03:00
a48f235636
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1357 Reviewed By: alexeib Differential Revision: D24377772 fbshipit-source-id: 51581af041d42d62166b33a35a1a4228b1a76f0c
103 lines
3.3 KiB
Python
103 lines
3.3 KiB
Python
#!/usr/bin/env python
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
|
|
|
|
class InputExample:
    """
    Container for one multiple-choice example.

    Attributes are stored exactly as passed in:
        paragraph: the context passage text.
        qa_list: the candidate question-answer strings for this example.
        label: the index of the correct candidate in ``qa_list``.
    """

    def __init__(self, paragraph, qa_list, label):
        self.paragraph = paragraph
        self.qa_list = qa_list
        self.label = label

    def __repr__(self):
        # Readable representation to ease debugging of extracted examples.
        return "{}(paragraph={!r}, qa_list={!r}, label={!r})".format(
            type(self).__name__, self.paragraph, self.qa_list, self.label
        )
|
|
|
|
|
|
def get_examples(data_dir, set_type):
    """
    Extract paragraph and question-answer list from each json file.

    Args:
        data_dir: root directory of the dataset; files are read from
            ``<data_dir>/<split>/<level>/``.
        set_type: split name, optionally suffixed with a single level as
            ``"<split>-<level>"`` (e.g. ``"test-high"``). Without a suffix,
            both the ``"middle"`` and ``"high"`` levels are read.

    Returns:
        A list of ``InputExample``, one per question. Each holds the
        whitespace-normalized article, four question+option strings and the
        zero-based index of the correct option (``"A"`` -> 0).

    Raises:
        FileNotFoundError: if a required level directory does not exist.
        KeyError / IndexError: if a json file lacks an expected field or has
            fewer questions/options than answers.
    """
    examples = []

    levels = ["middle", "high"]
    set_type_c = set_type.split("-")
    if len(set_type_c) == 2:
        # "<split>-<level>" restricts extraction to that single level.
        levels = [set_type_c[1]]
        set_type = set_type_c[0]

    whitespace = re.compile(r"\s+")

    for level in levels:
        cur_dir = os.path.join(data_dir, set_type, level)
        # os.listdir returns entries in arbitrary order; sort so the order of
        # the extracted examples is reproducible across runs and machines.
        for filename in sorted(os.listdir(cur_dir)):
            cur_path = os.path.join(cur_dir, filename)
            # Explicit encoding so parsing does not depend on the locale.
            with open(cur_path, "r", encoding="utf-8") as f:
                cur_data = json.load(f)

            answers = cur_data["answers"]
            options = cur_data["options"]
            questions = cur_data["questions"]
            # Collapse every whitespace run (newlines included) to one space.
            context = whitespace.sub(" ", cur_data["article"])

            for i, answer in enumerate(answers):
                label = ord(answer) - ord("A")
                question = questions[i]
                qa_list = []
                for j in range(4):
                    option = options[i][j]
                    if "_" in question:
                        # Cloze-style question: substitute the option for
                        # the blank marker(s).
                        qa_cat = question.replace("_", option)
                    else:
                        qa_cat = " ".join([question, option])
                    qa_list.append(whitespace.sub(" ", qa_cat))
                examples.append(InputExample(context, qa_list, label))

    return examples
|
|
|
|
|
|
def main():
    """
    Helper script to extract paragraphs questions and answers from RACE datasets.

    For each of the ``train``, ``dev``, ``test-middle`` and ``test-high``
    sets, writes six line-aligned files into ``--output-dir``:
    ``<set>.input0`` (paragraph), ``<set>.input1`` .. ``<set>.input4``
    (one question-answer string per option) and ``<set>.label``.
    """
    # Imported locally so this function does not depend on a module-level
    # import that the rest of the file does not otherwise need.
    from contextlib import ExitStack

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        required=True,
        help="input directory for downloaded RACE dataset",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="output directory for extracted data",
    )
    args = parser.parse_args()

    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(args.output_dir, exist_ok=True)

    for set_type in ["train", "dev", "test-middle", "test-high"]:
        examples = get_examples(args.input_dir, set_type)
        qa_file_paths = [
            os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
            for i in range(4)
        ]
        outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
        outf_label_path = os.path.join(args.output_dir, set_type + ".label")

        # ExitStack closes every file opened so far even if a later open()
        # or write fails, so no handle is leaked on error.
        with ExitStack() as stack:
            qa_files = [
                stack.enter_context(open(qa_file_path, "w", encoding="utf-8"))
                for qa_file_path in qa_file_paths
            ]
            outf_context = stack.enter_context(
                open(outf_context_path, "w", encoding="utf-8")
            )
            outf_label = stack.enter_context(
                open(outf_label_path, "w", encoding="utf-8")
            )

            for example in examples:
                outf_context.write(example.paragraph + "\n")
                # Index into qa_list so a short list raises instead of
                # silently leaving the output files misaligned.
                for i, qa_file in enumerate(qa_files):
                    qa_file.write(example.qa_list[i] + "\n")
                outf_label.write(str(example.label) + "\n")
|
|
|
|
|
|
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|