def extract_glimpse(inpt, attention_params, glimpse_size):
    """Extract an attention glimpse using separable row/column masks.

    Each channel of ``inpt`` is filtered as ``Fy^T @ channel @ Fx``, where
    ``Fy`` / ``Fx`` are built by ``gaussian_mask`` for the y and x axes.

    :param inpt: tensor of shape (batch_size, img_height, img_width) or
        (batch_size, img_height, img_width, channels). The channel dimension
        must be statically known, since channels are unstacked.
    :param attention_params: tensor whose last axis holds 6 parameters.
        Entries at even indices (0, 2, 4) are passed to the y-axis mask and
        entries at odd indices (1, 3, 5) to the x-axis mask. The expected
        layout is therefore presumably interleaved:
        [uy, ux, sy, sx, dy, dx] with u - mean, s - std, d - stride.
        NOTE(review): a layout of [uy, sy, dy, ux, sx, dx] would NOT match
        this slicing; confirm against ``gaussian_mask`` and the callers.
    :param glimpse_size: 2-tuple of ints as (height, width),
        size of the extracted glimpse.
    :return: tensor of shape (batch_size, height, width) for 3-D input, or
        (batch_size, height, width, channels) for 4-D input.
    """
    ap = attention_params
    shape = inpt.get_shape()
    rank = len(shape)
    assert rank in (3, 4), "Input must be 3 or 4 dimensional tensor"

    inpt_H, inpt_W = shape[1:3]
    single_channel = rank == 3
    if single_channel:
        # Add an explicit channel axis so both input ranks share the
        # per-channel code path below.
        inpt = inpt[..., tf.newaxis]
        rank += 1
    channel_axis = rank - 1

    # Even-indexed params drive the y (row) mask, odd-indexed the x mask.
    Fy = gaussian_mask(ap[..., 0::2], glimpse_size[0], inpt_H)
    Fx = gaussian_mask(ap[..., 1::2], glimpse_size[1], inpt_W)

    gs = [
        tf.matmul(tf.matmul(Fy, channel, adjoint_a=True), Fx)
        for channel in tf.unstack(inpt, axis=channel_axis)
    ]
    g = tf.stack(gs, axis=channel_axis)

    # tf.stack always yields a rank-4 tensor here. set_shape raises
    # ValueError when the given shape's rank differs from the tensor's, so
    # the static shape must carry the channel axis. For 3-D input, drop
    # the single channel added above so the output rank matches the input.
    if single_channel:
        g = g[..., 0]
        g.set_shape([shape[0]] + list(glimpse_size))
    else:
        g.set_shape([shape[0]] + list(glimpse_size) + [shape[3]])
    return g
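# (scraped web-page residue, not code) "Comment list"
# (scraped web-page residue, not code) "Table of contents"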