parser.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:arc-swift 作者: qipeng 项目源码 文件源码
def pos_loss_pred(self, i, pos_embeddings, pos_logit, NUM_POS, gold_pos, pos_trainables):
        if self.args.no_pos:
            pos_emb = tf.nn.embedding_lookup(pos_embeddings, gold_pos[i])
            if self.train:
                return 0, pos_emb
            else:
                return tf.gather(gold_pos[i], tf.range(1, self.sent_length)), pos_emb
        else:
            pos_logit = pos_logit[1:]

            log_partition = tf.reduce_logsumexp(pos_logit, [1])

            pos_pred = tf.exp(pos_logit - tf.reshape(log_partition, (-1, 1)))
            pos_emb = tf.concat([tf.reshape(tf.nn.embedding_lookup(pos_embeddings, NUM_POS), (1, -1)),
                tf.matmul(pos_pred, pos_trainables)], 0)

            if self.train:
                loss = tf.reduce_sum(tf.gather(log_partition, tf.range(self.sent_lengths[i]-1))
                    - tf.gather(tf.reshape(pos_logit, [-1]),
                        tf.range(self.sent_lengths[i]-1) * NUM_POS
                        + tf.gather(gold_pos[i], tf.range(1, self.sent_lengths[i]))))

                return loss, pos_emb
            else:
                return tf.cast(tf.argmax(pos_pred, 1), tf.int32), pos_emb
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号