modules.py source code

python

Project: DaNet-Tensorflow  Author: khaotik
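def __call__(self, s_embed, s_src_pwr, s_mix_pwr, s_embed_flat=None):
    # Excerpted method: assumes `import tensorflow as tf` (TF 1.x API) and a
    # module-level `hparams` object, neither of which is shown here.
    # Note: s_mix_pwr is accepted but not used in this method.
    if s_embed_flat is None:
        # Flatten the embedding to [B, N, E]: one E-dim vector per point.
        s_embed_flat = tf.reshape(
            s_embed,
            [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])
    with tf.variable_scope(self.name):
        # Label each point with the index of the largest power along axis 1
        # (presumably the source axis).
        s_src_assignment = tf.argmax(s_src_pwr, axis=1)
        s_indices = tf.reshape(
            s_src_assignment,
            [hparams.BATCH_SIZE, -1])
        # Despite its name, fn_segmean only computes per-source sums; the
        # division that turns them into means happens below.
        fn_segmean = lambda _: tf.unsorted_segment_sum(
            _[0], _[1], hparams.MAX_N_SIGNAL)
        # Per-example sum of embedding vectors for each source.
        s_attractors = tf.map_fn(
            fn_segmean, (s_embed_flat, s_indices), hparams.FLOATX)
        # Per-example count of points assigned to each source.
        s_attractors_wgt = tf.map_fn(
            fn_segmean, (tf.ones_like(s_embed_flat), s_indices),
            hparams.FLOATX)
        # Smoothed mean: the +1 avoids division by zero for empty sources.
        s_attractors /= (s_attractors_wgt + 1.)

    if hparams.DEBUG:
        self.debug_fetches = dict()
    # float[B, C, E]
    return s_attractors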
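
The method above appears to compute one attractor per source as a smoothed mean of the embedding vectors assigned to that source: a per-source sum divided by the per-source count plus one. A minimal pure-Python sketch of that arithmetic for a single example follows. The function name `attractors_for_one_example` and the toy data are hypothetical and not part of the DaNet-Tensorflow project; it illustrates the calculation only and does not reproduce the TensorFlow graph.

```python
def attractors_for_one_example(embeddings, assignments, n_sources):
    """Smoothed per-source mean of embedding vectors (illustrative only).

    embeddings:  list of N vectors, each of length E
    assignments: list of N source indices in [0, n_sources)
    """
    embed_size = len(embeddings[0])
    sums = [[0.0] * embed_size for _ in range(n_sources)]
    counts = [0] * n_sources
    # Equivalent of an unsorted segment sum: accumulate per source index.
    for vec, src in zip(embeddings, assignments):
        counts[src] += 1
        for k, value in enumerate(vec):
            sums[src][k] += value
    # Divide by (count + 1), mirroring `s_attractors_wgt + 1.` in the listing,
    # so a source with no assigned points yields a zero vector, not 0/0.
    return [[s / (counts[c] + 1.0) for s in sums[c]]
            for c in range(n_sources)]


# Toy data: three 2-dim points, two assigned to source 0 and one to source 1.
embeddings = [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]]
assignments = [0, 0, 1]
attractors = attractors_for_one_example(embeddings, assignments, n_sources=3)
```

Because of the `+ 1` in the denominator, each attractor is pulled slightly toward the origin compared with a plain mean, and the unused third source gets an all-zero attractor.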