def save_weights(fname, params, metadata=None):
    """Save parameter values to a compressed ``.npz`` archive.

    Each parameter's value is stored under its ``name``, so the archive can
    be read back with ``np.load``.  The archive is always written to a
    temporary sibling file first and then moved over the destination with
    ``os.replace``.  An interrupted save therefore never leaves a truncated
    archive at the destination.

    Args:
        fname: Destination path (str or path-like).  As with
            ``np.savez_compressed``, ``.npz`` is appended when missing.
        params: Iterable of objects exposing ``.name`` (a unique str) and
            ``.get_value(borrow=False)``.  These are presumably Theano-style
            shared variables; confirm against callers.  (An earlier note said
            batchnorm params are included now; that is decided by whatever
            the caller passes in ``params``.)
        metadata: Optional picklable object.  It is stored as pickled bytes
            under the reserved key ``'metadata'``.

    Raises:
        ValueError: if a parameter name is not a string, if names are not
            unique, or if a parameter is named ``'metadata'`` while
            ``metadata`` is given (it would otherwise be silently clobbered).
    """
    import os  # local import keeps this block self-contained

    params = list(params)  # iterated more than once below; accept any iterable
    names = [par.name for par in params]
    if not all(isinstance(name, str) for name in names):
        # np.savez_compressed takes names as **kwargs, which must be str.
        raise ValueError('every param needs a string name')
    if len(names) != len(set(names)):
        raise ValueError('need unique param names')
    if metadata is not None and 'metadata' in names:
        raise ValueError("param name 'metadata' is reserved when metadata is given")

    param_dict = {par.name: par.get_value(borrow=False) for par in params}
    if metadata is not None:
        param_dict['metadata'] = pickle.dumps(metadata)
    # NOTE: names that collide with np.savez_compressed's own parameters
    # (e.g. 'file') would still be rejected by numpy itself.

    target = os.fspath(fname)
    if not target.endswith('.npz'):
        # numpy appends '.npz' to bare paths; mirror that so the archive
        # always lands at the same place regardless of prior state.
        target += '.npz'
    tmp_path = os.path.splitext(target)[0] + '.tmp.npz'
    logging.info('saving %d parameters to %s', len(params), target)

    try:
        # Writing through a file object stops numpy from renaming the path.
        with open(tmp_path, 'wb') as fobj:
            np.savez_compressed(fobj, **param_dict)
            fobj.flush()
            os.fsync(fobj.fileno())  # data must be on disk before the swap
        # Atomic on POSIX (same directory); also overwrites on Windows.
        os.replace(tmp_path, target)
    except BaseException:
        # Best-effort removal of the partial temp file; re-raise the real error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
评论列表
文章目录