def average_models(best, L=6, model_dir='', model_name='ra.h5'):
    """Average the parameters of the snapshots around ``best`` into one file.

    Reads the HDF5 snapshots ``<model_dir>/<model_name>.<i>`` for every index
    ``i`` in the half-open window ``[max(best - L//2, 0), best + L//2)``.
    It writes the result to ``<model_dir>/<model_name>.merge``.

    Top-level entries whose name contains ``'#'`` are treated as parameters.
    They are cast to float32 and averaged over the snapshots actually read.
    Every other top-level entry is copied as-is, and file attributes are
    copied too. For both of these, the value from the last snapshot in the
    window wins.

    Args:
        best: Index of the central snapshot.
        L: Nominal window width. The window is clipped at index 0, and an odd
            ``L`` covers ``2 * (L // 2)`` snapshots.
        model_dir: Directory holding the snapshot files.
        model_name: Base filename shared by the snapshots.

    Raises:
        ValueError: If the window contains no snapshot indices.
    """
    half = L // 2  # floor division: identical result on Python 2 and 3
    indices = list(range(max(best - half, 0), best + half))
    if not indices:
        raise ValueError(
            'no snapshots to average for best={}, L={}'.format(best, L))
    # Divide by the number of snapshots actually read, not the nominal L, so
    # clipping at 0 or an odd L does not shrink the averaged weights.
    count = len(indices)

    print('... merging')
    # Report the inclusive range of snapshot indices that are really merged.
    print('{} {:d}-{:d}'.format(model_dir, indices[0], indices[-1]))

    params = {}
    side_info = {}
    attrs = {}
    for i in indices:
        path = osp.join(model_dir, model_name + '.' + str(i))
        with h5py.File(path, 'r') as f:
            for k, v in f.attrs.items():
                attrs[k] = v
            for p in f.keys():
                if '#' not in p:
                    # Non-parameter entry: keep the last snapshot's copy.
                    side_info[p] = f[p][...]
                    continue
                contribution = np.array(f[p]).astype('float32') / count
                if p in params:
                    params[p] += contribution
                else:
                    params[p] = contribution

    out_path = osp.join(model_dir, model_name + '.merge')
    with h5py.File(out_path, 'w') as f:
        for name, value in params.items():
            f[name] = value
        for name, value in side_info.items():
            f[name] = value
        for k, v in attrs.items():
            f.attrs[k] = v
# Comment list (web-page navigation residue)
# Table of contents (web-page navigation residue)