def circular_convolution(v, k, size):
    """Computes the circular convolution of `v` with the kernel `k`.

    For every position `i` in `range(size)` a window of `v` centred on `i`
    is gathered with wrap-around indexing, multiplied element-wise with the
    transposed kernel and summed along axis 0.

    Args:
      v: a `Tensor` that is split into `size` slices along axis 0.
        NOTE(review): presumably shaped `[size, 1]` -- confirm with callers.
      k: a `Tensor` whose second dimension is the kernel size (it is read
        via `k.get_shape()[1]`, so `k` cannot be truly 1-D).
        NOTE(review): presumably shaped `[1, kernel_size]` -- confirm.
      size: an int, the number of slices `v` is split into along axis 0,
        which is also the number of output rows.

    Returns:
      A `Tensor` built by concatenating the `size` per-position sums along
      axis 0.

    Note:
      The window always holds `2 * (kernel_size // 2) + 1` slices, so its
      length equals `kernel_size` only when `kernel_size` is odd.
    """
    kernel_size = int(k.get_shape()[1])
    kernel_shift = kernel_size // 2

    # Legacy TensorFlow argument order: tf.split(split_dim, num_split, value).
    v_list = tf.split(0, size, v)

    # Loop-invariant: the kernel is transposed once and reused for every window.
    k_column = tf.transpose(k)

    # Offsets run from +shift down to -shift, so kernel entry m is paired
    # with v[i + shift - m].
    offsets = range(kernel_shift, -kernel_shift - 1, -1)

    outputs = []
    for i in range(size):
        # Modulo wraps any offset back into [0, size), including kernels
        # whose half-width exceeds `size`.
        window = tf.concat(0, [v_list[(i + j) % size] for j in offsets])
        outputs.append(tf.reduce_sum(window * k_column, 0, keep_dims=True))
    return tf.concat(0, outputs)
评论列表
文章目录