ops.py source code

python

Project: mann-for-speech-separation  Author: KWTsou1220
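import math

import tensorflow as tf


def circular_convolution(v, k, size):
    """Computes circular convolution.
    Args:
        v: a 2-D `Tensor` of shape [size, batch] (one vector per column)
        k: a 2-D `Tensor` of shape [batch, kernel_size] (one kernel per column
           of v; a [1, kernel_size] kernel is broadcast across the batch)
        size: a Python int, the length of v along axis 0
    """
    kernel_size  = int(k.get_shape()[1])
    kernel_shift = int(math.floor(kernel_size/2.0))
    # Split v into `size` row slices, each of shape [1, batch].
    v_list = tf.split(v, size, axis=0)

    def loop(idx):
        # Wrap an out-of-range index around the ends of v (single wrap only,
        # so this assumes kernel_shift <= size).
        if idx < 0: return size + idx
        if idx >= size : return idx - size
        else: return idx

    kernels = []
    for i in range(size):
        # Rows of v under the kernel taps, from i+kernel_shift down to i-kernel_shift.
        indices = [loop(i+j) for j in range(kernel_shift, -kernel_shift-1, -1)]
        #v_ = tf.gather(v, indices)
        v_sublist = [v_list[idx] for idx in indices]
        v_        = tf.concat(v_sublist, axis=0)  # [kernel_size, batch]
        # Weight each row by its kernel tap and sum over the taps -> [1, batch].
        kernels.append(tf.reduce_sum(v_ * tf.transpose(k), 0, keepdims=True))

    return tf.concat(kernels, axis=0)  # [size, batch]
    #return tf.dynamic_stitch([i for i in range(size)], kernels)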
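To check the index arithmetic without TensorFlow, here is a minimal pure-Python sketch of the same computation. It assumes `v` is a size x batch nested list and `k` a batch x kernel_size nested list (one kernel per column, mirroring `tf.transpose(k)` above); the name `circular_convolution_ref` is hypothetical and not part of the project. For each output row i and column b it computes out[i][b] = sum over t of v[(i + shift - t) mod size][b] * k[b][t], with shift = floor(kernel_size / 2), which is the same set of taps as the `indices` list built in the loop, except that `%` handles wrap-around instead of the single-wrap `loop` helper.

```python
from typing import List, Sequence


def circular_convolution_ref(v: Sequence[Sequence[float]],
                             k: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical pure-Python reference for circular_convolution.

    v: size x batch nested list; k: batch x kernel_size nested list.
    Returns a size x batch nested list.
    """
    size = len(v)
    batch = len(v[0])
    kernel_size = len(k[0])
    shift = kernel_size // 2  # same as math.floor(kernel_size / 2.0)

    out = []
    for i in range(size):
        row = []
        for b in range(batch):
            acc = 0.0
            for t in range(kernel_size):
                # Tap t reads row (i + shift - t), wrapped circularly.
                acc += v[(i + shift - t) % size][b] * k[b][t]
            row.append(acc)
        out.append(row)
    return out


# Column 0 uses an identity kernel; column 1 puts all weight on the first tap.
print(circular_convolution_ref([[1, 10], [2, 20], [3, 30]],
                               [[0, 1, 0], [1, 0, 0]]))
```

With a 3-tap kernel the middle tap (t = 1) leaves a row in place, while the first tap pulls each value from the next row, so column 1 comes back rotated by one position. Feeding equivalent arrays through the TensorFlow version should give matching values.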