lstm.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:yaset 作者: jtourille 项目源码 文件源码
def loss_crf(self):
        """
        CRF based loss.
        :return: loss
        """

        # Reshaping seq_len tensor [seq_len, 1]
        seq_length_reshaped = tf.reshape(self.x_tokens_len, [tf.shape(self.x_tokens_len)[0], -1])

        # Computing loss by scanning mini-batch tensor
        out = tf.scan(self.loss_crf_scan, [self.prediction,
                                           seq_length_reshaped,
                                           self.y], back_prop=True, infer_shape=True, initializer=0.0)

        # Division by batch_size
        loss_crf = tf.divide(tf.reduce_sum(out), tf.cast(tf.shape(self.x_tokens)[0], dtype=tf.float32))

        return loss_crf
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号