def extract_patch(x, f_y, f_x, nchannels, normalize=False):
    """Extract a filtered patch from every channel of ``x``.

    For each batch element ``b`` and channel ``d`` the patch is
    ``transpose(f_y[b]) @ x[b, :, :, d] @ f_x[b]``. That is, the height axis
    of the image is mixed by ``f_y`` and the width axis by ``f_x``. The
    per-channel results are concatenated along the last axis.

    Args:
        x: [B, H, W, D] input tensor.
        f_y: [B, H, FH] filter matrices applied along the height axis.
        f_x: [B, W, FW] filter matrices applied along the width axis.
        nchannels: D, the number of channels of ``x`` to process. It must be
            a Python int, because it drives a Python ``range`` loop.
        normalize: currently unused by this function; kept only for
            interface compatibility.

    Returns:
        patch: [B, FH, FW, D] tensor.

    Note:
        This uses pre-1.0 TensorFlow APIs: ``tf.pack``, ``tf.batch_matmul``
        and ``tf.concat(axis, values)``. In TF >= 1.0 these became
        ``tf.stack``, ``tf.matmul`` and ``tf.concat(values, axis)``.
    """
    fsize_h = tf.shape(f_y)[2]
    fsize_w = tf.shape(f_x)[2]
    hh = tf.shape(x)[1]
    ww = tf.shape(x)[2]
    patch = []
    for dd in range(nchannels):
        # Select channel dd and drop its size-1 channel axis: [B, H, W].
        x_ch = tf.reshape(
            tf.slice(x, [0, 0, 0, dd], [-1, -1, -1, 1]), tf.pack([-1, hh, ww]))
        # adj_x=True uses the (conjugate) transpose of f_y:
        # [B, FH, H] @ [B, H, W] -> [B, FH, W].
        rows = tf.batch_matmul(f_y, x_ch, adj_x=True)
        # [B, FH, W] @ [B, W, FW] -> [B, FH, FW].
        ch_patch = tf.batch_matmul(rows, f_x)
        # Re-add a size-1 channel axis so the channels can be concatenated.
        patch.append(
            tf.reshape(ch_patch, tf.pack([-1, fsize_h, fsize_w, 1])))
    # Concatenate along axis 3 (old TF argument order: axis first).
    return tf.concat(3, patch)
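# NOTE: two lines of scraped web-page navigation text ("Comment list",
# "Table of contents") originally followed here; they are not code.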