net_utils.py source code

Project: tensorflow_yolo2  Author: wenxichen

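```python
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.train.Saver)

slim = tf.contrib.slim


def restore_inception_resnet_variables_from_weight(sess, weights_path):
    """Restore pretrained Inception-ResNet-v2 weights from a checkpoint,
    freshly initializing the first conv layer and the Adam optimizer state."""
    # Adam slot variables and the beta1/beta2 power accumulators are
    # initialized here rather than restored from the checkpoint.
    adam_vars = [var for var in tf.global_variables()
                 if 'Adam' in var.name or
                 'beta1_power' in var.name or
                 'beta2_power' in var.name]
    # Variables under the first conv layer's scope are also re-initialized.
    uninit_vars = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionResnetV2/Conv2d_1a_3x3') + adam_vars
    init_op = tf.variables_initializer(uninit_vars)

    # Restore every other variable, excluding the first conv layer and
    # any optimizer variables collected above.
    variables_to_restore = slim.get_variables_to_restore(
        exclude=['InceptionResnetV2/Conv2d_1a_3x3'])
    for var in uninit_vars:
        if var in variables_to_restore:
            variables_to_restore.remove(var)
    saver = tf.train.Saver(variables_to_restore)

    print('Initializing new variables to train from downloaded inception resnet weights')
    sess.run(init_op)
    saver.restore(sess, weights_path)

    return 0
```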
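The variable-selection logic in this function can be illustrated without TensorFlow. The sketch below is a hypothetical, framework-free helper (`partition_variable_names` is not part of tensorflow_yolo2 or of TensorFlow) that works on plain variable-name strings instead of `tf.Variable` objects. It approximates the scope filter as a simple prefix match; TensorFlow's `get_collection` actually applies `re.match` to each name. Every name lands in exactly one of two lists: restored from the checkpoint, or freshly initialized.

```python
from typing import Iterable, List, Tuple

# Values taken from the function above; the helper itself is hypothetical.
REINIT_SCOPE = 'InceptionResnetV2/Conv2d_1a_3x3'
OPTIMIZER_MARKERS = ('Adam', 'beta1_power', 'beta2_power')


def partition_variable_names(names: Iterable[str],
                             reinit_scope: str = REINIT_SCOPE
                             ) -> Tuple[List[str], List[str]]:
    """Split variable names into (to_restore, to_initialize).

    A name is initialized if it looks like optimizer state or lies under
    reinit_scope (prefix match, approximating TensorFlow's re.match);
    everything else is restored from the checkpoint.
    """
    to_restore: List[str] = []
    to_initialize: List[str] = []
    for name in names:
        is_optimizer_var = any(marker in name for marker in OPTIMIZER_MARKERS)
        in_reinit_scope = name.startswith(reinit_scope)
        if is_optimizer_var or in_reinit_scope:
            to_initialize.append(name)
        else:
            to_restore.append(name)
    return to_restore, to_initialize


if __name__ == '__main__':
    names = [
        'InceptionResnetV2/Conv2d_1a_3x3/weights:0',
        'InceptionResnetV2/Conv2d_2a_3x3/weights:0',
        'InceptionResnetV2/Conv2d_2a_3x3/weights/Adam:0',
        'beta1_power:0',
    ]
    restore, init = partition_variable_names(names)
    print(restore)  # → ['InceptionResnetV2/Conv2d_2a_3x3/weights:0']
    print(init)     # the other three names, in input order
```

Unlike the original, which lets the Conv2d_1a_3x3 Adam slots appear in both the exclude list and `adam_vars`, this sketch assigns each name once; the resulting restore and initialize sets are the same.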