mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-11 09:54:22 +03:00
averaging models scripts
This commit is contained in:
parent
b6e9b94ec0
commit
c3aed92595
20
scripts/average.py
Executable file
20
scripts/average.py
Executable file
@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys
|
||||
import numpy as np;
|
||||
|
||||
average = dict()
|
||||
|
||||
for filename in sys.argv[1:-1]:
|
||||
with open(filename, "rb") as mfile:
|
||||
m = np.load(mfile)
|
||||
for k in m:
|
||||
if k not in average:
|
||||
average[k] = m[k]
|
||||
elif average[k].shape == m[k].shape:
|
||||
average[k] += m[k]
|
||||
|
||||
for k in average:
|
||||
average[k] /= len(sys.argv[1:-1])
|
||||
|
||||
np.savez(sys.argv[-1], **average)
|
@ -24,8 +24,6 @@ int main(int argc, char* argv[]) {
|
||||
God::Init(argc, argv);
|
||||
boost::timer::cpu_timer timer;
|
||||
|
||||
LOG(info) << "Reading input";
|
||||
|
||||
std::string in;
|
||||
std::size_t taskCounter = 0;
|
||||
|
||||
@ -34,6 +32,8 @@ int main(int argc, char* argv[]) {
|
||||
LOG(info) << "Setting number of threads to " << threadCount;
|
||||
ThreadPool pool(threadCount);
|
||||
std::vector<std::future<History>> results;
|
||||
|
||||
LOG(info) << "Reading input";
|
||||
while(std::getline(std::cin, in)) {
|
||||
|
||||
results.emplace_back(
|
||||
|
Loading…
Reference in New Issue
Block a user