def extract_features(inputs, k_idxs, map_h):
    """Extract the top-k fine features.

    For each of the ``N_PATCHES`` selected locations, the flat index is
    converted into a (row, column) pair, a 14x14 patch is cut out of
    ``inputs`` at the corresponding origin, and all patches are pushed
    through ``fine_layers`` in a single batched call.

    NOTE.
      do not use tf.image.extract_glimpse ops to get input patches
      (cf. https://github.com/tensorflow/tensorflow/issues/2134)

    Args:
      inputs: tensor the patches are cut from; passed straight to
        ``extract_patches``.
      k_idxs: tensor of flat location indices; column ``i`` holds the
        index of the i-th selected location.
        (presumably shaped (batch, N_PATCHES) -- TODO confirm)
      map_h: divisor used to turn a flat index into (row, column).

    Returns:
      A tuple ``(k_features, k_src_idxs, k_patches)``:
        k_features: result of splitting the ``fine_layers`` output into
          ``N_PATCHES`` pieces along axis 0.
        k_src_idxs: list with one tensor per patch holding the
          (row, column) indices concatenated along axis 1.
        k_patches: list with one patch tensor per selected location.
    """
    def _extract_patch(idxs):
        """Return (src_idxs, patches) for one column of flat indices."""
        idxs = tf.expand_dims(idxs, 1)
        # NOTE(review): both row and column are derived with map_h, which
        # is only right when the map is square (or map_h is really the
        # row length) -- confirm against the caller.
        idx_i = tf.floordiv(idxs, map_h)
        idx_j = tf.mod(idxs, map_h)
        # NOTE: the below origins are starting points, not center!
        # The constants depend on the architecture; the expression is kept
        # unsimplified so the derivation stays visible.
        origin_i = 2*(2*idx_i+1)+3 - 5 + 2
        origin_j = 2*(2*idx_j+1)+3 - 5 + 2
        # Old-style TensorFlow argument order: tf.concat(concat_dim, values).
        origin_centers = tf.concat(1, [origin_i, origin_j])
        # NOTE: size also depends on the architecture
        patches = extract_patches(inputs, size=[14, 14],
                                  offsets=origin_centers)
        src_idxs = tf.concat(1, [idx_i, idx_j])
        return src_idxs, patches

    k_src_idxs = []
    k_patches = []
    # `range` instead of the Python-2-only `xrange`: N_PATCHES is a small
    # Python int, so the two are interchangeable here.
    for i in range(N_PATCHES):
        src_idx, patches = _extract_patch(k_idxs[:, i])
        k_src_idxs.append(src_idx)
        k_patches.append(patches)

    # Run the fine layers once over all patches stacked along axis 0, then
    # cut the result back into one piece per patch.
    concat_patches = tf.concat(0, k_patches)
    concat_k_features = fine_layers(concat_patches)
    # Old-style TensorFlow argument order: tf.split(split_dim, num_split, value).
    k_features = tf.split(0, N_PATCHES, concat_k_features)
    return k_features, k_src_idxs, k_patches
评论列表
文章目录