def extract_patches(inputs, size, offsets):
    """Crop one randomly jittered patch from each image in a batch.

    Every image is zero-padded by 2 pixels on both sides of axes 1 and 2.
    A patch of shape ``size`` is then sliced out of each padded image.
    The slice starts at ``offsets[i, :]`` minus a per-axis margin of 1 or 2,
    chosen at random.

    Args:
        inputs: 4-D tensor (the pad spec fixes rank 4), presumably
            (batch, height, width, 1). Axis 3 must have size 1 because it is
            squeezed away and re-added. The batch dimension must be known
            statically, since one slice op is built per image.
        size: Slice size for a single squeezed image, passed straight to
            ``tf.slice``; presumably ``[patch_height, patch_width]``.
        offsets: Tensor indexed as ``offsets[i, :]``.
            NOTE(review): presumably int32 with shape (batch, 2), because it
            is subtracted from an int32 constant. Confirm whether the start
            points are meant in padded or unpadded coordinates; they are used
            as-is against the padded image.

    Returns:
        The patches stacked along a new leading axis, with a trailing axis of
        size 1 appended.

    Raises:
        ValueError: If the batch dimension of ``inputs`` is not statically known.
    """
    batch_dim = inputs.get_shape()[0]
    # TF 1.x returns a Dimension (with `.value`); TF 2.x returns an int or None.
    batch_size = getattr(batch_dim, "value", batch_dim)
    if batch_size is None:
        raise ValueError(
            "extract_patches requires a statically known batch dimension")

    padded = tf.pad(inputs, [[0, 0], [2, 2], [2, 2], [0, 0]])
    # Squeeze only the channel axis. A bare tf.squeeze would also drop the
    # batch axis when batch_size == 1, and tf.unpack would then split along
    # rows instead of images.
    # (tf.unpack / tf.pack / tf.sub are the pre-1.0 names of
    # tf.unstack / tf.stack / tf.subtract.)
    images = tf.unpack(tf.squeeze(padded, [3]))

    extra_margins = tf.constant([1, 1, 2, 2])
    patches = []
    for i in range(batch_size):
        # Shuffling [1, 1, 2, 2] and keeping two entries gives each axis a
        # margin of 1 or 2. The pairs (1, 2) and (2, 1) occur with probability
        # 1/3 each; (1, 1) and (2, 2) occur with probability 1/6 each.
        margins = tf.random_shuffle(extra_margins)[:2]
        start_pts = tf.sub(offsets[i, :], margins)
        patches.append(tf.slice(images[i], start_pts, size))

    return tf.expand_dims(tf.pack(patches), 3)
评论列表
文章目录