draw.py source code

python

Project: DRAW  Author: RobRomijnders
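import tensorflow as tf

# attn_window, patch_read, A and B are defined elsewhere in draw.py
# (A, B: image width and height, 28 per the shape comments below;
#  patch_read: side length N of the square read glimpse).

def read(x,x_hat,h_dec_prev):
  """Function to implement eq 27 (attention read)"""
  Fx,Fy,gamma=attn_window("read",h_dec_prev,patch_read)
  # gamma in [batch_size,1]
  # Fx in [batch_size, patch_read, 28]
  def filter_img(img,Fx,Fy,gamma,N):
    Fxt=tf.transpose(Fx,perm=[0,2,1])
    img=tf.reshape(img,[-1,B,A])  # in [batch_size, 28,28]
    # tf.matmul batches over the leading dimension (tf.batch_matmul was removed in TF 1.0)
    glimpse=tf.matmul(Fy,tf.matmul(img,Fxt)) #in [batch_size, patch_read, patch_read]
    glimpse=tf.reshape(glimpse,[-1,N*N])     # in [batch_size, patch_read*patch_read]
    return glimpse*tf.reshape(gamma,[-1,1])  # scale each glimpse by its intensity gamma
  x=filter_img(x,Fx,Fy,gamma,patch_read) # batch x (patch_read*patch_read)
  x_hat=filter_img(x_hat,Fx,Fy,gamma,patch_read)
  # x in [batch_size, patch_read^2]
  # x_hat in [batch_size, patch_read^2]
  return tf.concat([x,x_hat],axis=1) # concat along feature axis (TF >= 1.0 argument order)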
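For readers without TensorFlow at hand, the read operation above can be sketched in NumPy to check the shapes. This is a minimal sketch under stated assumptions: `gaussian_filterbank` is a hypothetical stand-in for the module's `attn_window` (which is not shown here). It takes the grid centre, variance and stride directly instead of computing them from `h_dec_prev`, and it assumes the Gaussian filterbank formulation from the DRAW paper (filter centres at `g + (i - N/2 - 0.5) * delta`, each filter normalised to sum to 1).

```python
import numpy as np

def gaussian_filterbank(gx, gy, sigma2, delta, N, A, B):
    """Hypothetical helper (stand-in for attn_window): DRAW-style Gaussian filterbanks.

    gx, gy, sigma2, delta: arrays of shape [batch] (grid centre, variance, stride).
    Returns Fx [batch, N, A] and Fy [batch, N, B]; each filter row sums to 1.
    """
    grid = np.arange(1, N + 1) - N / 2 - 0.5               # filter offsets around the centre
    mu_x = gx[:, None] + grid[None, :] * delta[:, None]    # [batch, N]
    mu_y = gy[:, None] + grid[None, :] * delta[:, None]    # [batch, N]
    a = np.arange(A)[None, None, :]                        # pixel x-coordinates, [1, 1, A]
    b = np.arange(B)[None, None, :]                        # pixel y-coordinates, [1, 1, B]
    s2 = sigma2[:, None, None]                             # [batch, 1, 1]
    Fx = np.exp(-((a - mu_x[:, :, None]) ** 2) / (2 * s2))
    Fy = np.exp(-((b - mu_y[:, :, None]) ** 2) / (2 * s2))
    # Normalise each filter; the epsilon guards against division by zero.
    Fx /= np.maximum(Fx.sum(axis=2, keepdims=True), 1e-8)
    Fy /= np.maximum(Fy.sum(axis=2, keepdims=True), 1e-8)
    return Fx, Fy

def read(x, x_hat, Fx, Fy, gamma, A, B):
    """NumPy version of the read above: gamma * [Fy x Fx^T, Fy x_hat Fx^T], flattened."""
    N = Fx.shape[1]
    def filter_img(img):
        img = img.reshape(-1, B, A)                          # [batch, B, A]
        glimpse = Fy @ img @ Fx.transpose(0, 2, 1)           # [batch, N, N]
        return glimpse.reshape(-1, N * N) * gamma.reshape(-1, 1)
    return np.concatenate([filter_img(x), filter_img(x_hat)], axis=1)

# Usage with random data; batch, image size and attention parameters are illustrative.
rng = np.random.default_rng(0)
batch, A, B, N = 4, 28, 28, 5
x = rng.random((batch, A * B))
x_hat = x - rng.random((batch, A * B))   # stand-in for the error image
gx, gy = np.full(batch, A / 2), np.full(batch, B / 2)
sigma2, delta, gamma = np.ones(batch), np.full(batch, 2.0), np.ones(batch)
Fx, Fy = gaussian_filterbank(gx, gy, sigma2, delta, N, A, B)
r = read(x, x_hat, Fx, Fy, gamma, A, B)
print(r.shape)  # (4, 50)
```

Because every filter row sums to 1, reading an all-ones image with `gamma = 1` yields a glimpse of all ones, which is a quick sanity check on the filterbank normalisation. The output width is `2 * N * N`, matching the concatenation of the two glimpses along the feature axis.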