modellib.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:rec-attend-public 作者: renmengye 项目源码 文件源码
def extract_patch(x, f_y, f_x, nchannels, normalize=False):
  """
  Args:
      x: [B, H, W, D]
      f_y: [B, H, FH]
      f_x: [B, W, FH]
      nchannels: D
  Returns:
      patch: [B, FH, FW]
  """
  patch = [None] * nchannels
  fsize_h = tf.shape(f_y)[2]
  fsize_w = tf.shape(f_x)[2]
  hh = tf.shape(x)[1]
  ww = tf.shape(x)[2]

  for dd in range(nchannels):
    # [B, H, W]
    x_ch = tf.reshape(
        tf.slice(x, [0, 0, 0, dd], [-1, -1, -1, 1]), tf.pack([-1, hh, ww]))
    patch[dd] = tf.reshape(
        tf.batch_matmul(
            tf.batch_matmul(
                f_y, x_ch, adj_x=True), f_x),
        tf.pack([-1, fsize_h, fsize_w, 1]))

  return tf.concat(3, patch)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号