def conv(v, k):
    """Computes circular convolution.

    Output element ``i`` is ``sum_m v[(i + shift - m) % size] * k[m]``, where
    ``size = len(v)`` and ``shift = len(k) // 2``. The kernel is centred on
    position ``i``, and indices wrap around both ends of ``v``.

    Args:
      v: a 1-D `Tensor` (vector) whose length is statically known.
      k: a 1-D `Tensor` (kernel) whose length is statically known. The index
        arithmetic assumes an odd length. For an even length ``K``, each
        gathered window has ``K + 1`` entries, which does not match ``k``.

    Returns:
      A 1-D `Tensor` with the same length as `v`.
    """
    size = int(v.get_shape()[0])
    kernel_size = int(k.get_shape()[0])
    kernel_shift = kernel_size // 2
    # Row i lists the positions of v that line up with k[0], k[1], ... for
    # output element i. The `%` wraps any offset into [0, size). A single
    # +/- size correction is not enough when the kernel is wider than v.
    indices = [
        [(i + j) % size for j in range(kernel_shift, -kernel_shift - 1, -1)]
        for i in range(size)
    ]
    # One gather yields a [size, 2 * kernel_shift + 1] tensor of windows.
    # Summing along axis 1 gives the [size] result directly. This replaces
    # `size` separate gathers and the final tf.pack, which TF 1.0 renamed
    # to tf.stack.
    windows = tf.gather(v, indices)
    return tf.reduce_sum(windows * k, 1)
test_ntm_rotate.py 文件源码
python
阅读 90
收藏 0
点赞 0
评论 0
评论列表
文章目录