modules.py source code

python

Project: DaNet-Tensorflow  Author: khaotik
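def __call__(self, s_embed, s_src_pwr, s_mix_pwr, s_embed_flat=None):
    # Computes one attractor per source: the power-weighted mean of the
    # embeddings of all time-frequency bins assigned to that source.
    # Assumes `import tensorflow as tf` (TF 1.x API) and the project's
    # `hparams` module (BATCH_SIZE, EMBED_SIZE, MAX_N_SIGNAL, FLOATX, EPS, DEBUG).
    # Shape comments assume s_embed: float[B, T, F, E],
    # s_src_pwr: float[B, C, T, F], s_mix_pwr: float[B, T, F].
    if s_embed_flat is None:
        # float[B, T*F, E]: flatten the time-frequency axes
        s_embed_flat = tf.reshape(
            s_embed,
            [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])
    with tf.variable_scope(self.name):
        # float[B, T*F, 1]: mixture power used as a per-bin weight
        s_wgt = tf.reshape(
            s_mix_pwr, [hparams.BATCH_SIZE, -1, 1])
        # int64[B, T, F]: index of the dominant source in each bin
        s_src_assignment = tf.argmax(s_src_pwr, axis=1)
        # int64[B, T*F]
        s_indices = tf.reshape(
            s_src_assignment,
            [hparams.BATCH_SIZE, -1])

        def fn_segsum(args):
            # Per batch item: sum the rows of `values` grouped by source index
            values, segment_ids = args
            return tf.unsorted_segment_sum(
                values, segment_ids, hparams.MAX_N_SIGNAL)

        # float[B, C, E]: power-weighted sum of embeddings per source
        s_attractors = tf.map_fn(
            fn_segsum, (s_embed_flat * s_wgt, s_indices),
            dtype=hparams.FLOATX)
        # float[B, C, 1]: total weight per source
        s_attractors_wgt = tf.map_fn(
            fn_segsum, (s_wgt, s_indices),
            dtype=hparams.FLOATX)
        # Weighted mean; EPS keeps the division finite for empty sources
        s_attractors /= (s_attractors_wgt + hparams.EPS)

    if hparams.DEBUG:
        self.debug_fetches = dict()
    # float[B, C, E]
    return s_attractors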
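For readers without TensorFlow at hand, the computation above can be sketched in plain Python for a single batch item. This is a minimal illustration under stated assumptions, not the project's code: the helper names `segment_sum` and `attractors` are hypothetical, `segment_sum` only mimics the documented semantics of `tf.unsorted_segment_sum` for 2-D inputs, the layout of `src_pwr` (sources by bins) is assumed, and `eps` stands in for `hparams.EPS`.

```python
from typing import List, Sequence


def segment_sum(values: Sequence[Sequence[float]], segment_ids: Sequence[int],
                num_segments: int) -> List[List[float]]:
    """Hypothetical pure-Python analogue of tf.unsorted_segment_sum (2-D case):
    row i of `values` is added into output row segment_ids[i]; segments that
    receive no rows stay all zeros."""
    width = len(values[0]) if values else 0
    out = [[0.0] * width for _ in range(num_segments)]
    for row, seg in zip(values, segment_ids):
        for j, v in enumerate(row):
            out[seg][j] += v
    return out


def attractors(embed: Sequence[Sequence[float]],
               src_pwr: Sequence[Sequence[float]],
               mix_pwr: Sequence[float],
               num_sources: int,
               eps: float = 1e-8) -> List[List[float]]:
    """Hypothetical attractor computation for ONE batch item.

    embed:   N x E embeddings (N = flattened time-frequency bins)
    src_pwr: C x N per-source power in each bin (assumed layout)
    mix_pwr: length-N mixture power, used as the per-bin weight
    Returns num_sources x E: the power-weighted mean embedding per source.
    """
    n = len(embed)
    # Assign each bin to the source with the largest power (argmax over sources)
    assign = [max(range(len(src_pwr)), key=lambda c: src_pwr[c][i])
              for i in range(n)]
    # Weight each embedding row by its bin's mixture power
    weighted = [[w * e for e in row] for row, w in zip(embed, mix_pwr)]
    sums = segment_sum(weighted, assign, num_sources)
    wgts = segment_sum([[w] for w in mix_pwr], assign, num_sources)
    # Divide weighted sums by total weight; eps avoids 0/0 for empty sources
    return [[s / (wg[0] + eps) for s in row] for row, wg in zip(sums, wgts)]


if __name__ == "__main__":
    result = attractors(
        embed=[[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]],
        src_pwr=[[5.0, 4.0, 0.0], [1.0, 1.0, 3.0]],
        mix_pwr=[1.0, 3.0, 2.0],
        num_sources=3,
    )
    print(result)
```

In this example the first two bins go to source 0 and the third to source 1, so the attractors come out at roughly [2.5, 0.0] and [0.0, 2.0]. The third source slot receives no bins and stays [0.0, 0.0], which is the case the EPS term guards against in the TensorFlow version.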