def __init__(self, content, style, content_names, style_names,
             checkpoint_path="./vgg_19.ckpt"):
    """Extract VGG-19 reference activations for the content and style images.

    Each image is pushed through a freshly built VGG-19 graph restored from
    ``checkpoint_path``. The activations of the requested end points are
    evaluated to numpy arrays and stored on the instance. Afterwards the
    default graph is reset again, a randomly initialised target image
    variable is created, and ``self._build_graph()`` is called.

    Args:
        content: Content image as a numpy array (presumably H x W x 3 RGB in
            the 0-255 range -- TODO confirm; the per-channel mean is reshaped
            to [1, 1, 3]). It is cast to float32, so integer/float64 images
            are accepted.
        style: Style image, same conventions as ``content``.
        content_names: Collection of end-point names (keys of the dict
            returned by ``nets.vgg.vgg_19``) to keep for ``content``.
        style_names: Same as ``content_names``, for ``style``.
        checkpoint_path: VGG-19 checkpoint to restore weights from.
            Defaults to the previously hard-coded ``"./vgg_19.ckpt"``.

    Raises:
        ValueError: if none of the requested names matches an end point.
    """
    self.content_names = content_names
    self.style_names = style_names
    self.VGG_MEAN = [123.68, 116.78, 103.94]

    def _extract(image, layer_names):
        """Build VGG-19 on ``image`` in a fresh graph and evaluate ``layer_names``.

        Returns ``(end_points, feature_map)``: the full end-point dict from
        ``vgg_19`` and a dict mapping each requested name to its numpy value.
        """
        tf.reset_default_graph()
        # Cast explicitly: TF does not promote dtypes, so subtracting the
        # float32 mean from a uint8/float64 constant would raise a TypeError.
        image = tf.constant(np.asarray(image, dtype=np.float32))
        mean = tf.reshape(tf.constant(self.VGG_MEAN, dtype=tf.float32), [1, 1, 3])
        _, end_points = nets.vgg.vgg_19(
            tf.expand_dims(image - mean, axis=0),
            is_training=False,
            spatial_squeeze=False,
        )
        wanted = {
            name: tensor
            for name, tensor in end_points.items()
            if name in layer_names
        }
        if not wanted:
            raise ValueError(
                "none of the requested layers %r is a VGG-19 end point; "
                "available: %r" % (list(layer_names), sorted(end_points))
            )
        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint_path, slim.get_variables_to_restore()
        )
        with tf.Session() as sess:
            init_fn(sess)
            # Session.run accepts a dict of fetches and returns a dict with
            # the same keys, mapped to the evaluated numpy arrays.
            feature_map = sess.run(wanted)
        return end_points, feature_map

    # NOTE: the stored end-point dicts reference tensors of graphs that are
    # reset afterwards; they are kept only for backward compatibility.
    self.content_layers, self.content_map = _extract(content, content_names)
    self.style_layers, self.style_map = _extract(style, style_names)

    tf.reset_default_graph()
    # Random initial image with the content image's shape. Take the shape
    # from the numpy input rather than from a tensor of a discarded graph.
    initial = np.random.randint(0, 256, np.shape(content)).astype(np.float32)
    self.target = tf.Variable(initial, dtype=tf.float32, name="generate_image")
    self._build_graph()
评论列表
文章目录