def pixelnet_convs(inputs, num_class, is_training=True, reuse=False):
    """Build the VGG-like trunk and gather hypercolumn features on a pixel grid.

    Args:
        inputs: image batch tensor. Axes 0, 1, 2 are read as batch, height
            and width, so presumably NHWC -- TODO confirm against callers.
        num_class: not used by this function; kept for interface compatibility.
        is_training: if True, build the sampling grid from the spatial size of
            the last conv feature map and map it with
            ``ops.trace_locations_backward``; otherwise build it from the full
            input height/width and map it with ``ops.trace_locations_forward``.
        reuse: forwarded to the ``'vgg_16'`` variable scope.

    Returns:
        The values gathered by ``ops.extract_values`` reshaped to
        ``[num_batch, h, w, hyperchannels]``, where ``h, w`` is the sampling
        grid size. The tensor is also added to the ``'hyper_column'``
        collection; the trunk's last conv output is added to ``'last_conv'``.
    """
    input_shape = tf.shape(inputs)
    num_batch = input_shape[0]
    height = input_shape[1]
    width = input_shape[2]

    with tf.variable_scope('vgg_16', reuse=reuse):
        net, hyperfeats = nets.vgg_like(inputs)
        tf.add_to_collection('last_conv', net)

    with tf.name_scope('hyper_columns'):
        if is_training:
            # Sample pixels corresponding to the elements of the last feature
            # map. Prefer static sizes, but fall back to dynamic ones for any
            # spatial dim unknown at graph-construction time: a None here
            # would make tf.range / tf.reshape below fail.
            static_hw = net.get_shape().as_list()[1:3]
            net_shape = tf.shape(net)
            h, w = [dim if dim is not None else net_shape[axis]
                    for axis, dim in zip((1, 2), static_hw)]
            trace_locations = ops.trace_locations_backward
        else:
            # Sample every pixel of the input image.
            h, w = height, width
            trace_locations = ops.trace_locations_forward

        # Flattened (row-major, 'xy' indexing) grid coordinates, repeated
        # once per batch element.
        grid_x, grid_y = tf.meshgrid(tf.range(w), tf.range(h), indexing='xy')
        loc_x = tf.tile(tf.reshape(grid_x, [1, -1]), [num_batch, 1])
        loc_y = tf.tile(tf.reshape(grid_y, [1, -1]), [num_batch, 1])

        # Map the grid onto each hyper-feature map's own spatial resolution.
        locations = [
            trace_locations(loc_x, loc_y, [h, w],
                            [tf.shape(feat)[1], tf.shape(feat)[2]])
            for feat in hyperfeats
        ]
        net = ops.extract_values(hyperfeats, locations)

        # Use the static channel count when known; otherwise let reshape
        # infer it (a None entry would make tf.reshape fail).
        hyperchannels = net.get_shape().as_list()[-1]
        if hyperchannels is None:
            hyperchannels = -1
        net = tf.reshape(net, [num_batch, h, w, hyperchannels])
        tf.add_to_collection('hyper_column', net)
    return net
评论列表
文章目录