parser.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:arc-swift 作者: qipeng 项目源码 文件源码
def ASw_transition_loss_pred(self, i, j, combined_head, combined_dep, transition_logit, SHIFT):
        """Score the candidate transitions of entry (i, j) and return a loss or predictions.

        The candidate feature rows for the entry are looked up, every non-SHIFT
        candidate is scored through ``self.rel_merge`` followed by
        ``self.rel_dense``, and the scores are normalised with a log-sum-exp
        that also includes the SHIFT logit when the first feature row is SHIFT.

        Args:
            i, j: indices of the entry. In training mode they index
                ``self.trans_feat_ids`` / ``self.trans_feat_sizes`` /
                ``self.trans_labels`` as ``[i, j]``; otherwise the single flat
                index ``i * self.args.beam_size + j`` is used.
            combined_head: tensor gathered with column 1 of the arc feature rows.
            combined_dep: tensor gathered with column 2 of the arc feature rows.
            transition_logit: tensor indexed by column 3 of the first feature
                row to obtain the SHIFT logit.
            SHIFT: id compared against ``feat_ids[0, 0]`` to detect a SHIFT row.

        Returns:
            Training mode: the scalar ``log_partition - logit`` of the candidate
            selected by ``self.trans_labels[i, j]`` (label 0 selects SHIFT when
            a SHIFT row is present, and label k > 0 then selects arc entry k - 1).
            Otherwise: a flat vector of ``log_partition - logit`` for every
            candidate (SHIFT first when present), padded with 1e20 up to
            ``self.pred_output_size`` entries.
        """
        # Select the feature rows / row count belonging to this entry.
        if self.train:
            feat_ids = self.trans_feat_ids[i, j]
            feat_size = self.trans_feat_sizes[i, j]
        else:
            flat_idx = i * self.args.beam_size + j
            feat_ids = self.trans_feat_ids[flat_idx]
            feat_size = self.trans_feat_sizes[flat_idx]

        # has_shift is 1 when the first candidate row is SHIFT, else 0.
        first_is_shift = tf.equal(feat_ids[0, 0], SHIFT)
        has_shift = tf.cond(first_is_shift, lambda: tf.constant(1), lambda: tf.constant(0))
        shift_present = tf.greater(has_shift, 0)
        n_arcs = feat_size - has_shift

        # Rows after the optional SHIFT row are the arc candidates.
        arc_rows = tf.gather(feat_ids, tf.range(has_shift, feat_size))
        emb_shape = [n_arcs, self.args.rel_emb_dim]
        head_vecs = tf.reshape(tf.gather(combined_head, arc_rows[:, 1]), emb_shape)
        dep_vecs = tf.reshape(tf.gather(combined_dep, arc_rows[:, 2]), emb_shape)

        rel_logit = self.rel_dense(self.rel_merge(head_vecs, dep_vecs))
        flat_logit = tf.reshape(rel_logit, [-1])

        def shift_logit():
            # Only ever called from inside a tf.cond branch, so the lookup is
            # built within that branch rather than unconditionally.
            return transition_logit[feat_ids[0, 3]]

        def log_add_exp(x, y):
            # Subtract the larger operand before exponentiating.
            peak = tf.maximum(x, y)
            return peak + tf.log(tf.exp(x - peak) + tf.exp(y - peak))

        # Normaliser over all arc scores, extended by the SHIFT logit if present.
        arcs_lse = tf.reduce_logsumexp(flat_logit)
        log_partition = tf.cond(shift_present,
                                lambda: log_add_exp(arcs_lse, shift_logit()),
                                lambda: arcs_lse)
        neg_log_prob = log_partition - flat_logit

        if self.train:
            def loss_with_shift():
                # Label 0 stands for SHIFT; arc labels are shifted up by one.
                label = self.trans_labels[i, j]
                return tf.cond(tf.greater(label, 0),
                               lambda: neg_log_prob[label - 1],
                               lambda: log_partition - shift_logit())

            return tf.cond(shift_present,
                           loss_with_shift,
                           lambda: neg_log_prob[self.trans_labels[i, j]])

        # Prediction mode: one column with SHIFT (if any) followed by the arcs.
        def column_with_shift():
            shift_cell = tf.reshape(log_partition - shift_logit(), (-1, 1))
            return tf.concat([shift_cell, tf.reshape(neg_log_prob, (-1, 1))], 0)

        arc_pred = tf.cond(shift_present,
                           column_with_shift,
                           lambda: tf.reshape(neg_log_prob, (-1, 1)))

        # Pad with 1e20 so the output always has self.pred_output_size entries.
        filled = has_shift + n_arcs * rel_logit.get_shape()[1]
        pad_rows = tf.subtract(self.pred_output_size, filled)
        padding = 1e20 * tf.ones((pad_rows, 1), dtype=tf.float32)
        return tf.reshape(tf.concat([arc_pred, padding], 0), [-1])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号