def fix_variables(self, sess, pretrained_model):
    """Reload three VGG16 weight tensors from a checkpoint and adapt them.

    The fc6, fc7 and conv1_1 weights are first restored into temporary,
    non-trainable variables, then copied into the matching entries of
    ``self._variables_to_fix``:

    * fc6 / fc7 are reshaped to the shape of the model's own variable
      (conv-style weights to fc-style weights).
    * conv1_1 is reversed along axis 2 (RGB to BGR).

    Args:
        sess: Session used to run the restore and the assign ops.
        pretrained_model: Checkpoint path handed to ``Saver.restore``.
    """
    print('Fix VGG16 layers..')
    with tf.variable_scope('Fix_VGG16'), tf.device("/cpu:0"):
        # Staging variables shaped like the tensors stored in the checkpoint.
        fc6_staged = tf.get_variable(
            "fc6_conv", [7, 7, 512, 4096], trainable=False)
        fc7_staged = tf.get_variable(
            "fc7_conv", [1, 1, 4096, 4096], trainable=False)
        conv1_staged = tf.get_variable(
            "conv1_rgb", [3, 3, 3, 64], trainable=False)

        # Checkpoint names of the three tensors, relative to the net's scope.
        fc6_key = self._scope + "/fc6/weights"
        fc7_key = self._scope + "/fc7/weights"
        conv1_key = self._scope + "/conv1/conv1_1/weights"

        # Map checkpoint names onto the staging variables and load them.
        staging_saver = tf.train.Saver({
            fc6_key: fc6_staged,
            fc7_key: fc7_staged,
            conv1_key: conv1_staged,
        })
        staging_saver.restore(sess, pretrained_model)

        # fc6: copy into the model variable using that variable's own shape.
        fc6_target = self._variables_to_fix[fc6_key + ':0']
        fc6_reshaped = tf.reshape(fc6_staged, fc6_target.get_shape())
        sess.run(tf.assign(fc6_target, fc6_reshaped))

        # fc7: same treatment as fc6.
        fc7_target = self._variables_to_fix[fc7_key + ':0']
        fc7_reshaped = tf.reshape(fc7_staged, fc7_target.get_shape())
        sess.run(tf.assign(fc7_target, fc7_reshaped))

        # conv1_1: flip axis 2 of the kernel to swap RGB and BGR ordering.
        conv1_target = self._variables_to_fix[conv1_key + ':0']
        conv1_flipped = tf.reverse(conv1_staged, [2])
        sess.run(tf.assign(conv1_target, conv1_flipped))
评论列表
文章目录