inception_v1.py source code

python

Project: Deep_Learning_In_Action    Author: SunnyMarkLiu
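# Module-level imports this method relies on (TensorFlow 1.x with contrib.slim)
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python import pywrap_tensorflow


def load_pretrained_model(self):
    """
    Load pretrained checkpoint weights into the model's trainable variables.

    Every trainable variable that is not listed in self.skip_layer and that
    also exists in the checkpoint self.pre_trained_model_cpkt is overwritten
    with the checkpoint value; skipped layers keep their fresh initialization.
    This only restores values: it does not make the variables non-trainable.
    :return: None
    """
    print('Loading the pretrained weights...')
    trainable_variables = slim.get_variables(
        collection=tf.GraphKeys.TRAINABLE_VARIABLES)

    reader = pywrap_tensorflow.NewCheckpointReader(self.pre_trained_model_cpkt)
    # Maps checkpoint variable names (no ':0' suffix) to their shapes
    pretrained_model_variables = reader.get_variable_to_shape_map()
    for variable in trainable_variables:
        # Strip the output index, e.g. 'conv1/weights:0' -> 'conv1/weights'
        variable_name = variable.name.split(':')[0]
        if variable_name in self.skip_layer:
            continue  # layer is meant to be trained from scratch
        if variable_name not in pretrained_model_variables:
            continue  # checkpoint has no value for this variable
        print('load ' + variable_name)
        data = reader.get_tensor(variable_name)
        # `variable` already is the graph variable, so assign to it directly
        # rather than re-fetching it through tf.get_variable(reuse=True).
        self.sess.run(variable.assign(data))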
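The filtering rule inside load_pretrained_model can be checked without TensorFlow. The sketch below restates it in plain Python and is hedged in three ways. First, select_variables_to_restore and the example variable names are hypothetical. Second, skip_layer is assumed to hold full variable names. Third, the shape comparison is an addition that the original method does not perform. In the real method, the checkpoint shapes would come from reader.get_variable_to_shape_map() and the graph shapes from each variable's get_shape().

```python
from typing import Dict, List, Sequence, Set, Tuple


def select_variables_to_restore(
    graph_vars: List[Tuple[str, Tuple[int, ...]]],
    ckpt_shape_map: Dict[str, Sequence[int]],
    skip_layer: Set[str],
) -> List[str]:
    """Hypothetical helper: return the checkpoint keys safe to copy into the graph."""
    selected = []
    for tensor_name, shape in graph_vars:
        # Graph tensor names carry an output index ("conv1/weights:0");
        # checkpoint keys do not, so drop everything after the colon.
        var_name = tensor_name.split(":")[0]
        if var_name in skip_layer:
            continue  # explicitly excluded, e.g. a new classifier layer
        if var_name not in ckpt_shape_map:
            continue  # no pretrained value exists for this variable
        if tuple(ckpt_shape_map[var_name]) != tuple(shape):
            continue  # extra check (not in the original): assign would fail
        selected.append(var_name)
    return selected


graph_vars = [
    ("conv1/weights:0", (7, 7, 3, 64)),
    ("conv1/biases:0", (64,)),
    ("logits/weights:0", (1024, 10)),
    ("logits/biases:0", (10,)),
    ("aux/weights:0", (3, 3)),
]
ckpt = {
    "conv1/weights": [7, 7, 3, 64],
    "conv1/biases": [64],
    "logits/weights": [1024, 1001],
    "logits/biases": [1001],
}
print(select_variables_to_restore(graph_vars, ckpt, {"logits/biases"}))
# prints ['conv1/weights', 'conv1/biases']
```

The shape check matters when fine-tuning with a different number of classes. If a classifier variable is left out of skip_layer, the original method would reach assign with mismatched shapes and raise an error. The sketch silently skips that variable instead.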