modules.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:DaNet-Tensorflow 作者: khaotik 项目源码 文件源码
def __call__(self, s_embed, s_src_pwr=None, s_mix_pwr=None, s_embed_flat=None):
        """Pick one attractor set per batch item from combinations of anchors.

        Creates (or reuses) ``hparams.NUM_ANCHOR`` anchor vectors of size
        ``hparams.EMBED_SIZE`` and forms every ``hparams.MAX_N_SIGNAL``-sized
        combination of them via ``ops.combinations``. Each combination is
        turned into a set of attractors (equations 6-7). Per batch item, it
        keeps the set whose distinct attractors are least similar to each
        other (equations 8-9).

        Args:
            s_embed: rank-4 embedding tensor, indexed ``[b, t, f, e]`` by the
                einsums below (presumably batch, time, frequency, embedding;
                ``e`` must match ``hparams.EMBED_SIZE``).
            s_src_pwr, s_mix_pwr, s_embed_flat: accepted but not used here
                (presumably kept for a call signature shared with other
                modules -- verify against callers).

        Returns:
            Tensor indexed ``[b, c, e]`` with ``c == hparams.MAX_N_SIGNAL``:
            the selected attractor set for each batch item.

        Side effects:
            Creates/reuses the variable ``anchors`` under scope ``self.name``.
            If ``hparams.DEBUG`` is truthy, stores intermediate tensors in
            ``self.debug_fetches``.
        """
        with tf.variable_scope(self.name):
            v_anchors = tf.get_variable(
                'anchors', [hparams.NUM_ANCHOR, hparams.EMBED_SIZE],
                initializer=tf.random_normal_initializer(
                    stddev=1.))

            # all combinations of anchors; indexed [p, c, e] by the einsum below
            s_anchor_sets = ops.combinations(
                v_anchors, hparams.MAX_N_SIGNAL)

            # equation (6): soft assignment of each embedding to the anchors
            # of every set; softmax runs over the last axis (c)
            s_anchor_assignment = tf.einsum(
                'btfe,pce->bptfc',
                s_embed, s_anchor_sets)
            s_anchor_assignment = tf.nn.softmax(s_anchor_assignment)

            # equation (7): assignment-weighted mean of the embeddings
            # over t and f, giving attractor sets indexed [b, p, c, e]
            s_attractor_sets = tf.einsum(
                'bptfc,btfe->bpce',
                s_anchor_assignment, s_embed)
            s_attractor_sets /= tf.expand_dims(
                tf.reduce_sum(s_anchor_assignment, axis=(2,3)), -1)

            # equation (8): in-set similarity is the largest dot product
            # between two *distinct* attractors of a set. The Gram diagonal
            # holds each attractor's squared norm, which would otherwise
            # dominate the max, so it is replaced by the dtype's lowest value.
            # (With MAX_N_SIGNAL == 1 every entry is masked and the choice
            # below degenerates to set 0.)
            s_gram = tf.matmul(
                s_attractor_sets, s_attractor_sets, transpose_b=True)
            s_gram = tf.matrix_set_diag(
                s_gram,
                tf.ones_like(tf.matrix_diag_part(s_gram)) * s_gram.dtype.min)
            s_in_set_similarities = tf.reduce_max(s_gram, axis=(-1, -2))

            # equation (9): per batch item, choose the least self-similar set
            s_subset_choice = tf.argmin(s_in_set_similarities, axis=1)
            # pair each chosen set index with its batch index for gather_nd;
            # the batch size comes from the input, so it need not equal
            # hparams.BATCH_SIZE (e.g. a final partial batch)
            s_batch_idx = tf.cast(
                tf.range(tf.shape(s_embed)[0]), s_subset_choice.dtype)
            s_subset_choice = tf.stack(
                [s_batch_idx, s_subset_choice], axis=1)
            s_attractors = tf.gather_nd(s_attractor_sets, s_subset_choice)

        if hparams.DEBUG:
            self.debug_fetches = dict(
                asets=s_attractor_sets,
                anchors=v_anchors,
                subset_choice=s_subset_choice)

        return s_attractors
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号