def read(x, x_hat, h_dec_prev):
    """Attentive read (eq 27): gamma-scaled glimpses of ``x`` and ``x_hat``.

    Attention parameters are produced by ``attn_window("read", ...)`` from the
    previous decoder state.  The same filterbanks ``Fx``/``Fy`` and intensity
    ``gamma`` are applied to both inputs, and the two flattened glimpses are
    concatenated along the feature axis.

    Args:
        x: batch of flattened images; reshaped here to [batch, B, A].
        x_hat: second batch with the same layout as ``x`` (presumably the
            DRAW error image -- confirm against the caller).
        h_dec_prev: previous decoder hidden state, passed to ``attn_window``.

    Returns:
        Tensor of shape [batch, 2 * patch_read * patch_read]: the ``x``
        glimpse followed by the ``x_hat`` glimpse.
    """
    Fx, Fy, gamma = attn_window("read", h_dec_prev, patch_read)
    # gamma: [batch, 1]; Fx: [batch, patch_read, A]; Fy: [batch, patch_read, B]
    # (shapes implied by the matmuls below).

    # TensorFlow 1.0 removed tf.batch_matmul (tf.matmul now batches) and
    # changed tf.concat to take (values, axis).  Keep the original calls on
    # older TF and use the new API when the legacy op is absent.
    legacy_tf = hasattr(tf, "batch_matmul")
    bmm = tf.batch_matmul if legacy_tf else tf.matmul

    def filter_img(img, Fx, Fy, gamma, N):
        """Return gamma * (Fy @ img @ Fx^T), flattened to [batch, N*N]."""
        Fxt = tf.transpose(Fx, perm=[0, 2, 1])  # [batch, A, N]
        img = tf.reshape(img, [-1, B, A])  # [batch, B, A]
        glimpse = bmm(Fy, bmm(img, Fxt))  # [batch, N, N]
        glimpse = tf.reshape(glimpse, [-1, N * N])  # [batch, N*N]
        # Broadcast the per-example intensity over every glimpse pixel.
        return glimpse * tf.reshape(gamma, [-1, 1])

    x = filter_img(x, Fx, Fy, gamma, patch_read)  # [batch, patch_read^2]
    x_hat = filter_img(x_hat, Fx, Fy, gamma, patch_read)  # [batch, patch_read^2]
    # Concatenate along the feature axis (axis 1).
    if legacy_tf:
        return tf.concat(1, [x, x_hat])
    return tf.concat([x, x_hat], 1)
评论列表
文章目录