def save_weights(fname, params, history=None):
    """Save model parameters, and optionally a history object, to disk.

    The parameters are converted to a dict with ``convert2dict`` and then
    written in one of two formats, chosen by the extension of *fname*:

    * ``.npy``  -- written with ``np.save``.
    * anything else -- written as a gzip-compressed pickle
      (``pickle.HIGHEST_PROTOCOL``).

    If *history* is not ``None`` it is written with ``np.save`` to a file
    named ``history.npy`` in the same directory as *fname*.

    Args:
        fname: Destination path (str or path-like).
        params: Parameters to save. They are passed to ``convert2dict``;
            ``len(params)`` is used only for the log message.
        history: Optional object to store next to the weights. When ``None``
            nothing is written, so an existing ``history.npy`` is not
            clobbered with a meaningless ``None`` array.
    """
    param_dict = convert2dict(params)
    logging.info('saving %s parameters to %s', len(params), fname)

    fname = Path(fname)
    filename, ext = osp.splitext(fname)

    if history is not None:
        history_file = osp.join(osp.dirname(fname), 'history.npy')
        np.save(history_file, history)
        logging.info('Save history to %s', history_file)

    if ext == '.npy':
        # np.save stores a dict as a 0-d object array (pickled); reading it
        # back needs np.load(path, allow_pickle=True).item().
        np.save(filename + '.npy', param_dict)
    else:
        # The context manager flushes and closes the gzip stream even if
        # pickling fails part-way, instead of leaking the handle.
        with gzip.open(fname, 'wb') as f:
            pickle.dump(param_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
评论列表
文章目录