Add average_checkpoints.py script + unit test

Co-authored-by: theweiho <weiho@fb.com>
This commit is contained in:
Myle Ott 2018-03-13 10:15:26 -04:00
parent 7d19e36dc4
commit c777340aab
3 changed files with 160 additions and 0 deletions

0
scripts/__init__.py Normal file
View File

View File

@ -0,0 +1,88 @@
#!/usr/bin/env python3
import argparse
import collections
import torch
def average_checkpoints(inputs):
"""Loads checkpoints from inputs and returns a model with averaged weights.
Args:
inputs: An iterable of string paths of checkpoints to load from.
Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = collections.OrderedDict()
params_keys = None
new_state = None
for f in inputs:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
),
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
model_params = state['model']
model_params_keys = list(model_params.keys())
if params_keys is None:
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
'For checkpoint {}, expected list of params: {}, '
'but found: {}'.format(f, params_keys, model_params_keys)
)
for k in params_keys:
if k not in params_dict:
params_dict[k] = []
params_dict[k].append(model_params[k])
averaged_params = collections.OrderedDict()
# v should be a list of torch Tensor.
for k, v in params_dict.items():
summed_v = None
for x in v:
summed_v = summed_v + x if summed_v is not None else x
averaged_params[k] = summed_v / len(v)
new_state['model'] = averaged_params
return new_state
def main():
parser = argparse.ArgumentParser(
description='Tool to average the params of input checkpoints to '
'produce a new checkpoint',
)
parser.add_argument(
'--inputs',
required=True,
nargs='+',
help='Input checkpoint file paths.',
)
parser.add_argument(
'--output',
required=True,
metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this '
'path.',
)
args = parser.parse_args()
print(args)
new_state = average_checkpoints(args.inputs)
torch.save(new_state, args.output)
print('Finished writing averaged checkpoint to {}.'.format(args.output))
if __name__ == '__main__':
main()

View File

@ -0,0 +1,72 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import collections
import os
import tempfile
import unittest
import numpy as np
import torch
from scripts.average_checkpoints import average_checkpoints
class TestAverageCheckpoints(unittest.TestCase):
def test_average_checkpoints(self):
params_0 = collections.OrderedDict(
[
('a', torch.DoubleTensor([100.0])),
('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
('c', torch.IntTensor([7, 8, 9])),
]
)
params_1 = collections.OrderedDict(
[
('a', torch.DoubleTensor([1.0])),
('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
('c', torch.IntTensor([2, 2, 2])),
]
)
params_avg = collections.OrderedDict(
[
('a', torch.DoubleTensor([50.5])),
('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
# We expect truncation for integer division
('c', torch.IntTensor([4, 5, 5])),
]
)
fd_0, path_0 = tempfile.mkstemp()
fd_1, path_1 = tempfile.mkstemp()
torch.save(collections.OrderedDict([('model', params_0)]), path_0)
torch.save(collections.OrderedDict([('model', params_1)]), path_1)
output = average_checkpoints([path_0, path_1])['model']
os.close(fd_0)
os.remove(path_0)
os.close(fd_1)
os.remove(path_1)
for (k_expected, v_expected), (k_out, v_out) in zip(
params_avg.items(), output.items()):
self.assertEqual(
k_expected, k_out, 'Key mismatch - expected {} but found {}. '
'(Expected list of keys: {} vs actual list of keys: {})'.format(
k_expected, k_out, params_avg.keys(), output.keys()
)
)
np.testing.assert_allclose(
v_expected.numpy(),
v_out.numpy(),
err_msg='Tensor value mismatch for key {}'.format(k_expected)
)
if __name__ == '__main__':
unittest.main()