test_ntm_rotate.py source code

python

Project: Neural-Turing-Machine Author: yeoedward
import math

import tensorflow as tf


def conv(v, k):
    """Computes circular convolution.

    Args:
        v: a 1-D `Tensor` (vector)
        k: a 1-D `Tensor` (kernel); its length must be odd so that it
            matches the 2 * kernel_shift + 1 indices gathered below

    Returns:
        A 1-D `Tensor` with the same length as `v`.
    """
    size = int(v.get_shape()[0])
    kernel_size = int(k.get_shape()[0])
    kernel_shift = int(math.floor(kernel_size / 2.0))

    def loop(idx):
        # Wrap an out-of-range index back into [0, size).
        if idx < 0:
            return size + idx
        if idx >= size:
            return idx - size
        return idx

    kernels = []
    for i in range(size):
        # Entries of v that line up with k[0], k[1], ... for output position i.
        indices = [loop(i + j) for j in range(kernel_shift, -kernel_shift - 1, -1)]
        v_ = tf.gather(v, indices)
        kernels.append(tf.reduce_sum(v_ * k, 0))

    # # code with double loop (kept for reference, not used)
    # for i in range(size):
    #     for j in range(kernel_size):
    #         idx = i + kernel_shift - j + 1
    #         if idx < 0: idx = idx + size
    #         if idx >= size: idx = idx - size
    #         w = tf.gather(v, int(idx)) * tf.gather(kernel, j)
    #         output = tf.scatter_add(output, [i], tf.reshape(w, [1, -1]))

    # tf.pack was renamed to tf.stack in TensorFlow 1.0.
    return tf.stack(kernels)
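
To see what the index pattern in `conv` computes without needing TensorFlow, the following is a minimal pure-Python sketch of the same arithmetic. The name `circular_conv_reference` is hypothetical and not part of the original project. It assumes an odd-length kernel and plain lists of floats. With a one-hot kernel the output is simply the input vector rotated, which appears to be the behavior this file is meant to test.

```python
def circular_conv_reference(v, k):
    """Pure-Python reference for the index pattern used by conv().

    Hypothetical helper for illustration; assumes len(k) is odd.
    """
    size = len(v)
    shift = len(k) // 2
    out = []
    for i in range(size):
        total = 0.0
        # j runs from +shift down to -shift, pairing with k[0], k[1], ...
        for pos, j in enumerate(range(shift, -shift - 1, -1)):
            total += v[(i + j) % size] * k[pos]
        out.append(total)
    return out


v = [1.0, 2.0, 3.0, 4.0, 5.0]
print(circular_conv_reference(v, [0.0, 1.0, 0.0]))  # [1.0, 2.0, 3.0, 4.0, 5.0]
print(circular_conv_reference(v, [1.0, 0.0, 0.0]))  # [2.0, 3.0, 4.0, 5.0, 1.0]
print(circular_conv_reference(v, [0.0, 0.0, 1.0]))  # [5.0, 1.0, 2.0, 3.0, 4.0]
```

A kernel with all its weight in the center leaves the vector unchanged, weight in the first slot shifts every element one position toward index 0, and weight in the last slot shifts it the other way. The modulo here plays the role of the `loop` helper in the TensorFlow version.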