def describe_style(self, style_image, eval_out=False, pool_type='avg', last_layer='conv5_4'):
    """Describe a style image by Gaussian statistics of VGG conv activations.

    Runs ``style_image`` through the vgg network up to and including
    ``last_layer``. At every convolution layer on the way, the conv output
    (before the following relu) is flattened to a ``(#pixels, #channels)``
    matrix. Each pixel is treated as one observation of a Gaussian random
    vector in R^#channels. The fitted parameters are stored as
    ``self.style_desc[layer] = (mean, trace(covariance), sqrt(covariance))``.

    Args:
        style_image (PIL image object or array-like): displays the style to be
            transferred. It is expected to convert to an (H, W, C) array with
            C >= 3; only the first three channels are used.
        eval_out (bool): whether to open a tf session and evaluate the style
            description to numpy arrays (replacing the tensors stored in
            ``self.style_desc``).
        pool_type (str): 'avg', 'max', or 'none', type of pooling to use;
            passed through to ``self.pool_func``.
        last_layer (str): the vgg network will process the image up to and
            including this layer.

    Raises:
        ValueError: if ``last_layer`` is not one of ``self.all_layers``.

    Side effects:
        Sets ``self.style_desc``, ``self.style_arr`` and ``self.stop``.
    """
    # Validate before touching any state so a bad layer name leaves the
    # object unchanged (list.index would raise only after partial mutation).
    if last_layer not in self.all_layers:
        raise ValueError('last_layer %r is not a layer of the network; '
                         'expected one of %s' % (last_layer, list(self.all_layers)))

    with self.graph.as_default():
        self.style_desc = {}
        # Add a batch axis of size 1 and keep only the first three channels.
        self.style_arr = tf.constant((np.expand_dims(style_image, 0)[:, :, :, :3])
                                     .astype('float32'))
        x = self.style_arr - self.mean_pixel
        self.stop = self.all_layers.index(last_layer) + 1
        for layer in self.all_layers[:self.stop]:
            if layer[:2] == 're':
                x = tf.nn.relu(x)
            elif layer[:2] == 'po':
                x = self.pool_func(x, pool_type)
            elif layer[:2] == 'co':
                kernel = self.vgg_ph[layer + '_kernel']
                bias = self.vgg_ph[layer + '_bias']
                x = tf.nn.bias_add(tf.nn.conv2d(x, kernel,
                                                strides=(1, 1, 1, 1),
                                                padding='SAME'), bias)

                layer_shape = tf.shape(x, out_type=tf.int32)
                n_pixels = layer_shape[1] * layer_shape[2]
                # Flatten to (#pixels, #channels). This assumes batch == 1,
                # which holds because style_arr holds a single image.
                stl_activs = tf.reshape(x, [n_pixels, layer_shape[3]])
                # Mean with shape (1, #channels). Use expand_dims instead of
                # the reduce_mean keep-dims flag: that keyword was renamed
                # (keep_dims -> keepdims) across TF versions.
                mean_stl_activs = tf.expand_dims(tf.reduce_mean(stl_activs, axis=0), 0)
                centered = stl_activs - mean_stl_activs
                # Population (divide-by-N) covariance, (#channels, #channels).
                covar_stl_activs = (tf.matmul(centered, centered, transpose_a=True)
                                    / tf.cast(n_pixels, tf.float32))

                # Matrix square root of the covariance via eigendecomposition.
                # The root is needed for wdist, and tf cannot take the eig of
                # non-symmetric matrices. The covariance is symmetric, so
                # self_adjoint_eig applies. Negative eigenvalues are clamped
                # to 0, presumably to absorb round-off before the sqrt.
                eigvals, eigvects = tf.self_adjoint_eig(covar_stl_activs)
                clamped_eigvals = tf.maximum(eigvals, 0.)
                eigval_mat = tf.diag(tf.sqrt(clamped_eigvals))
                root_covar_stl_activs = tf.matmul(tf.matmul(eigvects, eigval_mat),
                                                  eigvects, transpose_b=True)
                # trace(covariance) = sum of its (clamped) eigenvalues.
                trace_covar_stl = tf.reduce_sum(clamped_eigvals)

                self.style_desc[layer] = (mean_stl_activs,
                                          trace_covar_stl,
                                          root_covar_stl_activs)

        if eval_out:
            with tf.Session(graph=self.graph, config=self.config) as sess:
                self.style_desc = sess.run(self.style_desc, feed_dict=self.feed_dict)
评论列表
文章目录