def read_attention(self, x, x_hat, h_dec_prev):
    """Apply the read attention window to the input image x and the error
    image x_hat, returning the two N x N glimpses concatenated."""
    Fx, Fy, gamma = self.attn_window("read", h_dec_prev)
    # We now have the parameters for a grid of Gaussian filters; apply them.
    def filter_img(img, Fx, Fy, gamma):
        Fxt = tf.transpose(Fx, perm=[0, 2, 1])
        img = tf.reshape(img, [-1, self.img_size, self.img_size])
        # Apply the Gaussian filterbanks along both axes of the image.
        # Shapes per batch element:
        #   Fy:  [attention_n, img_size]   (vertical filters)
        #   img: [img_size, img_size]
        #   Fxt: [img_size, attention_n]   (horizontal filters, transposed)
        # so Fy * img * Fxt yields an [attention_n, attention_n] glimpse.
        # With the batch dimension, batch_matmul multiplies
        # [batch, attention_n, img_size] x [batch, img_size, img_size] x [batch, img_size, attention_n].
        glimpse = tf.batch_matmul(Fy, tf.batch_matmul(img, Fxt))
        glimpse = tf.reshape(glimpse, [-1, self.attention_n ** 2])
        # Finally, scale the glimpse by the scalar intensity gamma.
        return glimpse * tf.reshape(gamma, [-1, 1])
    x = filter_img(x, Fx, Fy, gamma)
    x_hat = filter_img(x_hat, Fx, Fy, gamma)
    # Concatenate the glimpse of the input with the glimpse of the error image.
    return tf.concat(1, [x, x_hat])
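# ---------------------------------------------------------------
# Shape sanity check -- a minimal NumPy sketch, not part of the model code.
# Fy is [batch, attention_n, img_size], the image is [batch, img_size, img_size],
# and Fx transposed is [batch, img_size, attention_n], so the double matmul
# collapses each image to an attention_n x attention_n glimpse.
# The sizes below are illustrative assumptions.
import numpy as np

batch, img_size, attention_n = 4, 28, 5
Fx = np.random.rand(batch, attention_n, img_size)    # horizontal filterbank
Fy = np.random.rand(batch, attention_n, img_size)    # vertical filterbank
img = np.random.rand(batch, img_size, img_size)      # batch of square images
gamma = np.random.rand(batch, 1)                      # per-example intensity

Fxt = np.transpose(Fx, (0, 2, 1))                     # [batch, img_size, attention_n]
glimpse = Fy @ img @ Fxt                              # [batch, attention_n, attention_n]
glimpse = glimpse.reshape(batch, attention_n ** 2) * gamma
print(glimpse.shape)                                  # (4, 25)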
# encode an attention patch