checkpoints.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:triple-gan 作者: zhenxuan00 项目源码 文件源码
def save_weights(fname, params, history=None):
    param_dict = convert2dict(params)

    logging.info('saving {} parameters to {}'.format(len(params), fname))
    fname = Path(fname)

    filename, ext = osp.splitext(fname)
    history_file = osp.join(osp.dirname(fname), 'history.npy')
    np.save(history_file, history)
    logging.info("Save history to {}".format(history_file))
    if ext == '.npy':
        np.save(filename + '.npy', param_dict)
    else:
        f = gzip.open(fname, 'wb')
        pickle.dump(param_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
        f.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号