GANcheckpoints.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:Neural-Photo-Editor 作者: ajbrock 项目源码 文件源码
def load_weights(fname, params):
    # params = lasagne.layers.get_all_params(l_out,trainable=True)+[log_sigma]+[x for x in lasagne.layers.get_all_params(l_out) if x.name[-4:]=='mean' or x.name[-7:]=='inv_std']
    names = [ par.name for par in params ]
    if len(names)!=len(set(names)):
        raise ValueError('need unique param names')

    param_dict = np.load(fname)
    for param in params:
        if param.name in param_dict:
            stored_shape = np.asarray(param_dict[param.name].shape)
            param_shape = np.asarray(param.get_value().shape)
            if not np.all(stored_shape == param_shape):
                warn_msg = 'shape mismatch:'
                warn_msg += '{} stored:{} new:{}'.format(param.name, stored_shape, param_shape)
                warn_msg += ', skipping'
                warnings.warn(warn_msg)
            else:
                param.set_value(param_dict[param.name])
        else:
            logging.warn('unable to load parameter {} from {}'.format(param.name, fname))
    if 'metadata' in param_dict:
        metadata = pickle.loads(str(param_dict['metadata']))
    else:
        metadata = {}
    return metadata
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号