def replace_features(coarse_features, fine_features, replace_idxs):
    """Overwrite selected spatial positions of coarse features with fine features.

    Trick: every coarse value is handed to ``tf.dynamic_stitch`` first and
    the fine values second.  When two index lists name the same position,
    ``dynamic_stitch`` keeps the value from the later list, so the fine
    values replace the coarse ones at the requested positions while every
    other position keeps its coarse value.

    Args:
        coarse_features: 4-D tensor, unpacked here as
            (batch, height, width, channels).  Its static shape must be fully
            defined, because each dimension is read via ``Dimension.value``
            and used in Python integer arithmetic.
        fine_features: list of tensors, each flattened and then concatenated
            in order.  Each one must flatten to ``batch * channels`` values in
            batch-major / channel-minor order, so that it lines up with the
            indices built from the matching ``replace_idxs`` entry
            (presumably shape (batch, channels) -- confirm against callers).
        replace_idxs: list of integer tensors, parallel to ``fine_features``.
            Column 0 is used as a row (height) index and column 1 as a column
            (width) index.  A per-batch offset of length ``batch`` is added to
            the rows, so each tensor is assumed to have exactly ``batch``
            rows: one (row, col) pair per batch item.

    Returns:
        A tuple ``(merged, flat_coarse_replaced, flat_fine_features)``:
        ``merged`` has the shape of ``coarse_features``, with the selected
        positions overwritten by fine values; ``flat_coarse_replaced`` is a
        1-D tensor of the coarse values that were overwritten (needed for
        hint-based training), one per entry of ``flat_fine_features``;
        ``flat_fine_features`` is the 1-D concatenation of ``fine_features``.
    """
    # TODO: indexing could be simplified further with broadcasting.
    batch_size, map_height, map_width, map_channel = [
        dim.value for dim in coarse_features.get_shape()]
    # Element strides of a row-major (batch, height, width, channels) layout.
    row_stride = map_width * map_channel
    batch_stride = map_height * row_stride
    total_size = batch_size * batch_stride

    def _convert_to_1d_idxs(src_idxs):
        """Map one (row, col) pair per batch item to flat element indices.

        Returns a 1-D tensor of ``batch * channels`` indices into the
        flattened ``coarse_features``, ordered batch first, channel second.
        """
        # Offset of each (row, col) position inside its own image.
        pixel_1d = row_stride * src_idxs[:, 0] + map_channel * src_idxs[:, 1]
        # Shift row b into batch item b (this broadcast needs `batch` rows).
        batch_offsets = [b * batch_stride for b in range(batch_size)]
        pixel_1d = tf.add(pixel_1d, batch_offsets)
        # Expand each position to all of its channels: (C, N) -> (N, C) -> flat.
        per_channel = [pixel_1d + c for c in range(map_channel)]
        return tf.reshape(tf.transpose(tf.pack(per_channel)), [-1])

    # Flatten coarse features to 1-D (row-major order).
    flat_coarse_features = tf.reshape(coarse_features, [-1])
    # Flatten and concatenate fine features and their target indices
    # (pre-1.0 signature: tf.concat(axis, values)).
    flat_fine_features = tf.concat(
        0, [tf.reshape(feat, [-1]) for feat in fine_features])
    flat_fine_idxs = tf.concat(
        0, [_convert_to_1d_idxs(idxs) for idxs in replace_idxs])
    # Extract the coarse values about to be overwritten;
    # this is required for hint-based training.
    flat_coarse_replaced = tf.gather(flat_coarse_features, flat_fine_idxs,
                                     validate_indices=False)
    # Fine values are listed second so they win at coinciding indices.
    merged = tf.dynamic_stitch(
        [tf.range(0, total_size), flat_fine_idxs],
        [flat_coarse_features, flat_fine_features])
    merged = tf.reshape(merged,
                        [batch_size, map_height, map_width, map_channel])
    return merged, flat_coarse_replaced, flat_fine_features
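# 评论列表 ("Comment list") -- scraped web-page residue, not code.
# 文章目录 ("Table of contents") -- scraped web-page residue, not code.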