model_new.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:deeplearning 作者: zxjzxj9 项目源码 文件源码
def __init__(self, content, style, content_names, style_names,
             checkpoint_path="./vgg_19.ckpt"):
        """Precompute VGG-19 feature targets and create the image variable.

        The content and style images are each run once through a pretrained
        slim ``nets.vgg.vgg_19`` network. The activations of the requested
        end points are cached as numpy arrays. The default graph is then
        reset, the trainable ``generate_image`` variable is created, and
        ``self._build_graph()`` is called.

        Args:
            content: Content image as a numpy array. It is cast to float32
                and the per-channel ``VGG_MEAN`` (reshaped to [1, 1, 3]) is
                subtracted. Presumably (H, W, 3) RGB in 0-255 -- TODO confirm.
            style: Style image, same conventions as ``content``.
            content_names: Collection of keys of the ``end_points`` dict
                returned by ``vgg_19`` whose activations become content
                targets. Names that are not end points are ignored.
            style_names: Same as ``content_names``, for the style targets.
            checkpoint_path: Path of the VGG-19 checkpoint to restore.
                Defaults to "./vgg_19.ckpt", the previously hard-coded value.

        Raises:
            ValueError: If none of ``content_names`` (or ``style_names``)
                is an end point of the network.
        """
        self.content_names = content_names
        self.style_names = style_names
        self.VGG_MEAN = [123.68, 116.78, 103.94]

        # Capture the shape from the numpy input now. Reading it from a
        # tensor whose graph has been reset yields a fragile TensorShape.
        content_shape = np.shape(content)

        def _extract(image, names):
            # Run `image` through a fresh VGG-19 and fetch the end points
            # listed in `names`. Returns (end_points, {name: ndarray}).
            tf.reset_default_graph()
            # TF does not auto-promote dtypes: a uint8/float64 image minus
            # the float32 mean constant would raise, so cast explicitly.
            centered = (tf.constant(image, dtype=tf.float32)
                        - tf.reshape(tf.constant(self.VGG_MEAN), [1, 1, 3]))
            _, end_points = nets.vgg.vgg_19(tf.expand_dims(centered, axis=0),
                                            is_training=False,
                                            spatial_squeeze=False)

            selected = [(key, tensor) for key, tensor in end_points.items()
                        if key in names]
            if not selected:
                raise ValueError(
                    "none of the requested layer names %r is a VGG-19 end "
                    "point; available: %r" % (names, list(end_points)))
            keys, tensors = zip(*selected)

            init_fn = slim.assign_from_checkpoint_fn(
                checkpoint_path, slim.get_variables_to_restore())
            with tf.Session() as sess:
                init_fn(sess)
                values = sess.run(tensors)
            return end_points, dict(zip(keys, values))

        # NOTE: the *_layers dicts hold tensors from graphs that are reset
        # below. They are kept as attributes only for backward compatibility.
        self.content_layers, self.content_map = _extract(content, content_names)
        self.style_layers, self.style_map = _extract(style, style_names)

        tf.reset_default_graph()
        # Random-noise initialisation, same shape as the content image.
        self.target = tf.Variable(np.random.randint(0, 256, content_shape),
                                  dtype=tf.float32, name="generate_image")
        self._build_graph()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号