cross_entropy_direct.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:master-thesis 作者: AndreasMadsen 项目源码 文件源码
def cross_entropy_direct(logits, target, name):
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                          labels=target)
    loss *= tf.cast(tf.not_equal(target, tf.zeros_like(target)), tf.float32)

    batch_loss = tf.reduce_sum(loss, axis=1)
    batch_loss = tf.reduce_mean(batch_loss, axis=0)
    batch_loss = tf.check_numerics(batch_loss, f'check/cross_entropy/{name}')

    return batch_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号