# Assumed module-level imports in inception_v1.py:
#   import tensorflow as tf
#   import tensorflow.contrib.slim as slim
#   from tensorflow.python import pywrap_tensorflow

def load_pretrained_model(self):
    """
    Load the pretrained weights from the checkpoint into the non-trainable (frozen) layers.
    :return: None
    """
    print('Loading the pretrained weights into the non-trainable layers...')
    from tensorflow.python.framework import ops
    # Variables currently registered as trainable in the graph.
    trainable_variables = slim.get_variables(None, None, ops.GraphKeys.TRAINABLE_VARIABLES)
    # Open the checkpoint and map each stored variable name to its shape.
    reader = pywrap_tensorflow.NewCheckpointReader(self.pre_trained_model_cpkt)
    pretrained_model_variables = reader.get_variable_to_shape_map()
    for variable in trainable_variables:
        variable_name = variable.name.split(':')[0]
        # Layers listed in self.skip_layer are trained from scratch, so leave them untouched.
        if variable_name in self.skip_layer:
            continue
        # Ignore variables that the checkpoint does not provide.
        if variable_name not in pretrained_model_variables:
            continue
        print('load ' + variable_name)
        # Re-fetch the variable by name and copy the checkpoint value into it.
        with tf.variable_scope('', reuse=True):
            var = tf.get_variable(variable_name, trainable=False)
            data = reader.get_tensor(variable_name)
            self.sess.run(var.assign(data))
Source file: inception_v1.py
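The loop above only restores variables whose names also appear in the checkpoint, so it helps to know exactly which names the checkpoint exports (for example, when deciding what to list in self.skip_layer). Below is a minimal sketch using the same reader API; the checkpoint path 'inception_v1.ckpt' and the printed example are placeholder assumptions, not part of the original file.

from tensorflow.python import pywrap_tensorflow

# Open the checkpoint and print every stored variable with its shape.
reader = pywrap_tensorflow.NewCheckpointReader('inception_v1.ckpt')  # placeholder path
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)  # e.g. InceptionV1/Conv2d_1a_7x7/weights [7, 7, 3, 64]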