train.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:onsager_deep_learning 作者: mborgerding 项目源码 文件源码
def load_trainable_vars(sess,filename):
    """load a .npz archive and assign the value of each loaded
    ndarray to the trainable variable whose name matches the
    archive key.  Any elements in the archive that do not have
    a corresponding trainable variable will be returned in a dict.
    """
    other={}
    try:
        tv=dict([ (str(v.name),v) for v in tf.trainable_variables() ])
        for k,d in np.load(filename).items():
            if k in tv:
                print('restoring ' + k)
                sess.run(tf.assign( tv[k], d) )
            else:
                other[k] = d
    except IOError:
        pass
    return other
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号