def nearest_neighbor_patch(x, patch, patch_norm):
    """Replace each feature vector of ``x`` with its best-matching patch column.

    Every spatial position of ``x`` yields a ``ch``-dimensional vector taken
    across the channel axis. It is scored against every column of ``patch``
    by dot product with the patch column divided by its norm. The raw
    (un-normalized) column with the highest score is copied into the output
    at that position.

    Args:
        x: Variable whose ``data`` is a 4-D array of shape ``(b, ch, h, w)``.
        patch: Variable with mini-batch size 1 whose ``data`` is reshaped to
            ``(ch, n_patches)``, i.e. one candidate vector per column.
        patch_norm: Variable with mini-batch size 1 whose ``data`` is reshaped
            to ``(1, n_patches)``. Presumably the L2 norm of each patch
            column, computed by the caller (TODO confirm). A norm of exactly
            0 is treated as 1, so an all-zero patch scores 0.

    Returns:
        Variable (presumably chainer's) wrapping an array of shape
        ``(b, ch, h, w)`` built from columns of ``patch``.

    Raises:
        AssertionError: if the mini-batch size of ``patch`` or ``patch_norm``
            is not 1.
    """
    assert patch.data.shape[0] == 1, 'mini batch size of patch must be 1'
    assert patch_norm.data.shape[0] == 1, 'mini batch size of patch_norm must be 1'
    xp = cuda.get_array_module(x.data)

    b, ch, h, w = x.data.shape
    # (b, ch, h, w) -> (ch, b*h*w): one query column per spatial position.
    z = x.data.transpose((1, 0, 2, 3)).reshape((ch, -1))
    # z is deliberately NOT normalized. Dividing a whole row of `correlation`
    # by the positive norm of its query vector cannot change that row's
    # argmax. Skipping the division also avoids 0/0 = NaN for all-zero
    # feature vectors (e.g. if x comes from a ReLU layer).

    p = patch.data.reshape((ch, -1))
    p_norm = patch_norm.data.reshape((1, -1))
    # A zero norm would make its column 0/0 = NaN. argmax treats NaN as the
    # maximum, so that patch would be picked at every position. Divide by 1
    # instead, so the zero patch simply scores 0.
    safe_norm = xp.where(p_norm > 0, p_norm, xp.ones_like(p_norm))
    p_normalized = p / safe_norm  # (1, n) broadcasts over the channel axis

    # correlation[i, j] = <z[:, i], p[:, j]> / ||p[:, j]||, shape (b*h*w, n).
    correlation = z.T.dot(p_normalized)
    best_index = xp.argmax(correlation, axis=1)

    # Gather the raw patch columns: (ch, b*h*w) -> (ch, b, h, w) -> (b, ch, h, w).
    nearest_neighbor = p.take(best_index, axis=1)
    nearest_neighbor = nearest_neighbor.reshape((ch, b, h, w)).transpose((1, 0, 2, 3))
    return Variable(nearest_neighbor)
评论列表
文章目录