hm.py source file

python
Views: 27 | Favorites: 0 | Likes: 0 | Comments: 0

Project: holographic_memory | Author: jramapuram | project source | file source
def fft_circ_conv1d(X, keys, batch_size, num_copies, num_keys=None, conj=False):
        """Circularly convolve ``X`` with ``keys`` via FFT, then average per chunk.

        Both operands are split to complex form, their FFTs are multiplied
        element-wise and inverse-transformed (by the convolution theorem this is
        a circular convolution), and the result is re-cast to real values.
        When ``conj`` is True the keys are first passed through
        ``HolographicMemory.conj_real_by_complex`` -- presumably turning the
        operation into a circular correlation (unbinding); TODO confirm -- and
        the result is re-cast with ``unsplit_from_complex_ir`` instead of
        ``unsplit_from_complex_ri``.

        Args:
            X: input tensor. It is tiled ``kshp[0] // batch_size`` times along
                axis 0 so that its rows line up with the rows of ``keys``.
            keys: key tensor. Its static shape (with overrides below) is passed
                to ``HolographicMemory.split_to_complex``.
            batch_size: number of input rows per group; it determines the tile
                factor and the maximum chunk size used for averaging.
            num_copies: only reported in the debug output; unused otherwise.
            num_keys: if not None, overrides the axis-0 size of the key shape
                used for tiling, complex splitting and chunking.
            conj: conjugate the keys before convolving (see above).

        Returns:
            A tensor built by averaging every complete run of ``batch_iter``
            consecutive rows of the convolution and concatenating the per-run
            means along axis 0, where ``batch_iter`` is
            ``min(batch_size, rows of the complex-split input)``.
            Trailing rows that do not fill a complete run are dropped.

        Raises:
            ValueError: if the key shape has fewer than ``batch_iter`` rows,
                so there is no complete run to average.
        """
        if conj:
            keys = HolographicMemory.conj_real_by_complex(keys)

        # Static shapes; fill in unknown key dimensions from num_keys / X.
        xshp = X.get_shape().as_list()
        kshp = keys.get_shape().as_list()
        kshp[0] = num_keys if num_keys is not None else kshp[0]
        kshp[1] = xshp[1] if kshp[1] is None else kshp[1]
        print('X : %s | keys : %s | batch_size = %s' % (xshp, kshp, batch_size))

        # Tile the input by the ratio number_keys / batch_size, e.g.
        #   |input| = [2, 784] ; |keys| = 3*[2, 784] (3 is num_copies)
        #   |new_input| = 6/2 |input| = [input; input; input]
        # At test time: |memories| = [3, 784] ; |keys| = 3*[n, 784]
        #   |new_input| = 3n / 3 = n  (n = number of parallel retrievals)
        # Floor division keeps the tile multiple an int under true division
        # too (tf.tile needs integer multiples); it equals Python 2's `/` on
        # ints.
        num_dupes = kshp[0] // batch_size
        print('num dupes = %s' % num_dupes)
        tiled_x = tf.tile(X, [num_dupes, 1]) if num_dupes > 1 else X
        xcplx = HolographicMemory.split_to_complex(tiled_x)
        xshp = xcplx.get_shape().as_list()
        kcplx = HolographicMemory.split_to_complex(keys, kshp)

        # Multiply in the frequency domain, invert, and re-cast to real values.
        # NOTE: the product is deliberately left unbounded; an earlier variant
        # wrapped it in HolographicMemory.bound(...).
        unsplit_func = (HolographicMemory.unsplit_from_complex_ir if conj
                        else HolographicMemory.unsplit_from_complex_ri)
        fft_mul = tf.mul(tf.fft(xcplx), tf.fft(kcplx))
        conv = unsplit_func(tf.ifft(fft_mul))
        print('full conv = %s' % conv.get_shape().as_list())

        batch_iter = min(batch_size, xshp[0]) if xshp[0] is not None else batch_size
        print('batch = %s | num_copies = %s | num_keys = %s | xshp[0] = %s'
              ' | len(keys) = %s | batch iter = %s'
              % (batch_size, num_copies, num_keys, xshp[0], kshp[0], batch_iter))

        # Average each complete run of batch_iter rows; a trailing partial run
        # is dropped.
        num_chunks = kshp[0] // batch_iter
        if num_chunks == 0:
            raise ValueError(
                'fft_circ_conv1d: need at least batch_iter=%s key rows to '
                'average, got %s' % (batch_iter, kshp[0]))
        conv_concat = [
            tf.expand_dims(
                tf.reduce_mean(conv[i * batch_iter:(i + 1) * batch_iter], 0), 0)
            for i in range(num_chunks)]
        print('conv concat = %s x %s'
              % (len(conv_concat), conv_concat[0].get_shape().as_list()))

        # Return a single concatenated tensor: C = [c0; c1; ...]
        return tf.concat(0, conv_concat)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号