def positive_conv(a, b):
    """Pairwise convolution on the positive domain of batches of 1-d vectors.

    For every batch row ``i`` and output index ``n`` in ``[0, domain_size)``
    the result is::

        res[i, n] = sum_{m=0}^{n} a[i, n - m] * b[i, m]

    i.e. the linear convolution of ``a[i]`` and ``b[i]`` truncated to its
    first ``domain_size`` entries. Row ``i`` of ``a`` is only ever combined
    with row ``i`` of ``b``.

    Args:
        a: discrete function on the positive domain (e.g. real-valued vector
            with a[0] = f(0), etc). Shape of [batch_size, domain_size]; both
            dimensions must be statically known.
        b: same as a.

    Returns:
        Discrete function on the positive domain representing the convolution
        of a and b, with shape [batch_size, domain_size].

    Raises:
        ValueError: if the static batch size or domain size of ``a`` is
            unknown.
    """
    batch_size = a.get_shape().dims[0].value
    width = a.get_shape().dims[1].value
    if batch_size is None or width is None:
        raise ValueError(
            "positive_conv needs a statically known shape for `a`, got %s"
            % a.get_shape())

    # Left-pad `a` with `width` zeros, then lay it out as
    # [2 * width, batch_size] so the batch dimension becomes the channel
    # dimension of the depthwise-conv input below.
    a = tf.pad(a, [[0, 0], [width, 0]])
    a = tf.transpose(a)

    # conv2d-style ops compute cross-correlation; reversing `b` along the
    # domain axis turns that into a true convolution. After padding and
    # transposing, the domain axis is axis 0, and tf.gather reverses it on
    # every TF version (tf.reverse's second argument changed from a boolean
    # mask to a list of axes in TF 1.0, so the mask form is not portable).
    # The result is [b[:, width-1], ..., b[:, 0], width zero rows].
    b = tf.pad(b, [[0, 0], [width, 0]])
    b = tf.transpose(b)
    b = tf.gather(b, list(range(2 * width - 1, -1, -1)))

    # Input in NHWC layout: N=1, H=1, W=2*width, C=batch_size.
    reshaped_a = tf.reshape(a, [1, 1, width * 2, batch_size])
    # Filter: [height=1, width=2*width, in_channels=batch_size, multiplier=1].
    # A depthwise conv applies filter channel i only to input channel i, so
    # each batch row of `a` meets only the matching row of `b`.
    reshaped_b = tf.reshape(b, [1, width * 2, batch_size, 1])
    res = tf.nn.depthwise_conv2d(
        reshaped_a, reshaped_b, strides=[1, 1, 1, 1], padding="SAME")
    # [1, 1, 2*width, batch_size] -> [batch_size, 2*width].
    res = tf.reshape(tf.transpose(res), [batch_size, width * 2])
    # The second half holds convolution outputs for indices 0 .. width-1.
    res = tf.slice(res, [0, width], [batch_size, width])
    return res
fft_tree_indep_inference.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录