Add scripts for working with txt files containing document boundaries

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/736

Differential Revision: D15314626

Pulled By: myleott

fbshipit-source-id: 1e0c32529afee57e43fe5d6c7991cd13eb8a52c4
This commit is contained in:
Myle Ott 2019-05-12 16:28:51 -07:00 committed by Facebook Github Bot
parent 43722c5e2b
commit 287d31e210
3 changed files with 199 additions and 0 deletions

61
scripts/count_docs.py Normal file
View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Count the number of documents and average number of lines and tokens per
document in a large file. Documents should be separated by a single empty line.
"""
import argparse
import gzip
import random
import sys
import numpy as np
def main():
    """Count documents in a file and report per-document statistics.

    Documents are separated by a single empty line. Prints the number of
    documents and the average number of lines and whitespace-delimited
    tokens per document. Progress dots go to stderr for large files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('--gzip', action='store_true')
    args = parser.parse_args()

    def gopen():
        # BUG FIX: open gzip input in *text* mode ('rt'); the original used
        # 'r' (binary for gzip), yielding bytes lines while the plain-file
        # branch yields str.
        if args.gzip:
            return gzip.open(args.input, 'rt', encoding='utf-8')
        else:
            return open(args.input, 'r', encoding='utf-8')

    num_lines = []
    num_toks = []
    with gopen() as h:
        num_docs = 0
        num_lines_in_doc = 0
        num_toks_in_doc = 0

        def end_doc():
            # Flush the per-document counters into the statistics lists.
            # Skipping empty buffers means consecutive or trailing blank
            # lines do not create phantom zero-length documents.
            nonlocal num_docs, num_lines_in_doc, num_toks_in_doc
            if num_lines_in_doc > 0:
                num_docs += 1
                num_lines.append(num_lines_in_doc)
                num_toks.append(num_toks_in_doc)
                num_lines_in_doc = 0
                num_toks_in_doc = 0

        for i, line in enumerate(h):
            if len(line.strip()) == 0:  # empty line indicates new document
                end_doc()
            else:
                num_lines_in_doc += 1
                num_toks_in_doc += len(line.rstrip().split())
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
        # BUG FIX: the original never flushed the final document when the
        # file did not end with a blank line, so the last doc was counted
        # in num_docs but excluded from the line/token averages.
        end_doc()
    print(file=sys.stderr, flush=True)

    print("found {} docs".format(num_docs))
    print("average num lines per doc: {}".format(np.mean(num_lines)))
    print("average num toks per doc: {}".format(np.mean(num_toks)))


if __name__ == '__main__':
    main()

55
scripts/shard_docs.py Normal file
View File

@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Split a large file into shards while respecting document boundaries. Documents
should be separated by a single empty line.
"""
import argparse
import contextlib
import random
import sys
def main():
    """Split a file of blank-line-separated documents into round-robin shards.

    Shard ``i`` is written to ``<input>.shard<i>``. Whole documents are kept
    together, distributed round-robin, and the single-blank-line separator
    convention is preserved within each shard.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('--num-shards', type=int)
    args = parser.parse_args()
    assert args.num_shards is not None and args.num_shards > 1

    with open(args.input, 'r', encoding='utf-8') as h:
        with contextlib.ExitStack() as stack:
            # ExitStack closes every shard file even if an error occurs.
            outputs = [
                stack.enter_context(
                    open(args.input + ".shard" + str(i), "w", encoding="utf-8")
                )
                for i in range(args.num_shards)
            ]

            doc = []
            first_doc = [True] * args.num_shards

            def output_doc(i):
                # Flush the buffered document to shard i, preceded by a blank
                # separator line for every document after the shard's first.
                # BUG FIX: return early on an empty buffer (trailing blank
                # line or consecutive blank lines in the input); the original
                # still flipped first_doc[i] and could emit a spurious
                # separator "\n" into the shard.
                if not doc:
                    return
                if not first_doc[i]:
                    outputs[i].write("\n")
                first_doc[i] = False
                outputs[i].writelines(doc)
                doc.clear()

            num_docs = 0
            for line in h:
                if line.strip() == "":  # empty line indicates new document
                    output_doc(num_docs % args.num_shards)
                    num_docs += 1
                else:
                    doc.append(line)
            # Flush the final document if the file didn't end with a blank line.
            output_doc(num_docs % args.num_shards)


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Split a large file into a train and valid set while respecting document
boundaries. Documents should be separated by a single empty line.
"""
import argparse
import random
import sys
def main():
    """Reservoir-sample k documents into one file; write the rest to another.

    Documents are separated by single empty lines. ``sample_output`` receives
    exactly ``k`` uniformly sampled documents; ``remainder_output`` receives
    all the others. The blank-line separator convention is preserved in both
    outputs. Progress dots go to stderr for large files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('sample_output', help='train output file')
    parser.add_argument('remainder_output', help='valid output file')
    parser.add_argument('-k', type=int, help="remainder size")
    args = parser.parse_args()
    assert args.k is not None

    sample = []
    remainder = []
    num_docs = 0

    def update_sample(doc):
        # Reservoir sampling (Algorithm R): the first k documents fill the
        # reservoir; each subsequent document replaces a uniformly random
        # slot with probability k / (num_docs + 1). Evicted and unsampled
        # documents accumulate in `remainder`.
        nonlocal num_docs
        if len(sample) < args.k:
            sample.append(doc.copy())
        else:
            j = random.randrange(num_docs + 1)
            if j < args.k:
                remainder.append(sample[j])
                sample[j] = doc.copy()
            else:
                remainder.append(doc.copy())
        num_docs += 1
        doc.clear()

    with open(args.input, 'r', encoding='utf-8') as h:
        doc = []
        for i, line in enumerate(h):
            if line.strip() == "":  # empty line indicates new document
                update_sample(doc)
            else:
                doc.append(line)
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
        if len(doc) > 0:
            update_sample(doc)
        print(file=sys.stderr, flush=True)

    assert len(sample) == args.k

    def write_docs(path, docs):
        # Write documents separated by single blank lines (no leading or
        # trailing separator). Factored out: the original duplicated this
        # loop verbatim for both output files.
        with open(path, 'w', encoding='utf-8') as out:
            for idx, d in enumerate(docs):
                if idx > 0:
                    out.write("\n")
                out.writelines(d)

    write_docs(args.sample_output, sample)
    write_docs(args.remainder_output, remainder)


if __name__ == '__main__':
    main()