parse_model.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:lang2program 作者: kelvinguu 项目源码 文件源码
def __init__(self, simple_scorer, attention_scorer, soft_copy_scorer):
        """Sum the scores of three predicate scorers and normalize them.

        The per-scorer scores are added element-wise to form the combined
        scores, and also stacked along a new last axis so each scorer's
        contribution stays available as a "subscore". A softmax over the
        combined scores gives the probabilities.

        Args:
            simple_scorer (SimplePredicateScorer)
            attention_scorer (AttentionPredicateScorer)
            soft_copy_scorer (SoftCopyPredicateScorer)
        """
        assert isinstance(simple_scorer, SimplePredicateScorer)
        assert isinstance(attention_scorer, AttentionPredicateScorer)
        assert isinstance(soft_copy_scorer, SoftCopyPredicateScorer)

        # Each of these is a SequenceBatch of shape (batch_size, num_candidates).
        simple = simple_scorer.scores
        attention = attention_scorer.scores
        soft_copy = soft_copy_scorer.scores

        def finite_check(batch, label):
            # Replace the pad value with 0 before checking, so the finiteness
            # check is applied to the values after padding has been overwritten.
            return tf.verify_tensor_all_finite(batch.with_pad_value(0).values, label)

        # Build the check ops up front; everything created inside the
        # control_dependencies block below will depend on them.
        finite_checks = [
            finite_check(simple, 'simple_scores'),
            finite_check(attention, 'attention_scores'),
            finite_check(soft_copy, 'soft copy scores'),
        ]

        with tf.control_dependencies(finite_checks):
            total = simple.values + attention.values + soft_copy.values
            combined = SequenceBatch(total, simple.mask)

            # Stack on axis 2: presumably (batch_size, num_candidates, 3) -- TODO confirm.
            stacked = tf.pack(
                [simple.values, attention.values, soft_copy.values], axis=2)
            subscores = SequenceBatch(stacked, simple.mask)

        # Pad with -inf before the softmax; presumably so that padded
        # candidates end up with zero probability -- verify against
        # SequenceBatch.with_pad_value.
        padded = combined.with_pad_value(float('-inf'))
        probs = SequenceBatch(tf.nn.softmax(padded.values), padded.mask)

        self._scores = padded
        self._subscores = subscores
        self._probs = probs

        self._simple_scorer = simple_scorer
        self._attention_scorer = attention_scorer
        self._soft_copy_scorer = soft_copy_scorer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号