def restore_inception_resnet_variables_from_weight(sess, weights_path):
    """Restore most graph variables from a checkpoint and initialize the rest.

    Two groups of variables are initialized with their own initializers
    instead of being read from the checkpoint:

    * every global variable under the ``InceptionResnetV2/Conv2d_1a_3x3``
      scope, and
    * every global variable whose name contains ``Adam``, ``beta1_power``
      or ``beta2_power`` (presumably the state created by an Adam
      optimizer -- confirm against the training code).

    All other variables reported by ``slim.get_variables_to_restore`` are
    loaded from ``weights_path`` through a ``tf.train.Saver``.

    Args:
        sess: Session used to run the initializer op and the restore.
        weights_path: Checkpoint path handed to ``Saver.restore``.

    Returns:
        Always ``0``.
    """
    first_conv_scope = 'InceptionResnetV2/Conv2d_1a_3x3'
    optimizer_markers = ('Adam', 'beta1_power', 'beta2_power')

    # Variables selected purely by substring match on their name.
    adam_vars = [var for var in tf.global_variables()
                 if any(marker in var.name for marker in optimizer_markers)]
    uninit_vars = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=first_conv_scope) + adam_vars
    init_op = tf.variables_initializer(uninit_vars)

    # Drop anything that is going to be freshly initialized, so the Saver
    # never looks those variables up in the checkpoint.
    variables_to_restore = [
        var
        for var in slim.get_variables_to_restore(exclude=[first_conv_scope])
        if var not in uninit_vars
    ]
    saver = tf.train.Saver(variables_to_restore)

    # Single-argument call form prints identically on Python 2 and Python 3.
    print('Initializing new variables to train from downloaded inception resnet weights')
    sess.run(init_op)
    saver.restore(sess, weights_path)
    return 0
评论列表
文章目录