losses.py source code

python

Project: opinatt  Author: epochx
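```python
# TensorFlow 1.x internal ops used by the original snippet.
from tensorflow.python.ops import array_ops, math_ops

# sequence_loss_by_example and sequence_loss_by_batch are not defined in this
# excerpt; they are assumed to be defined or imported elsewhere in losses.py.


def get_sequence_loss(logits, targets, weights, softmax_loss_function=None,
                      per_example_loss=False):
  """Sequence cross-entropy loss, either per example or for the whole batch."""
  # logits, targets and weights are parallel lists with one entry per time step.
  assert len(logits) == len(targets)
  # Cast each target to an int64 tensor and flatten it to shape [batch_size].
  bucket_target = [array_ops.reshape(math_ops.to_int64(x), [-1]) for x in targets]
  if per_example_loss:
    # One loss value per example in the batch.
    crossent = sequence_loss_by_example(logits, bucket_target, weights,
                                        softmax_loss_function=softmax_loss_function)
  else:
    # Batch-level loss (helper defined elsewhere in the project).
    crossent = sequence_loss_by_batch(logits, bucket_target, weights,
                                      softmax_loss_function=softmax_loss_function)
  return crossent
```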
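To make the difference between the two branches concrete, here is a minimal pure-Python sketch of what they compute. It assumes sequence_loss_by_example follows TensorFlow 1.x legacy seq2seq semantics (weighted softmax cross-entropy summed over time steps, then divided by the summed weights). Because sequence_loss_by_batch is not shown in this file, the sketch also assumes the batch variant averages the per-example losses over the batch, as TF's legacy sequence_loss does. The helper names softmax_xent, loss_by_example and loss_by_batch are hypothetical, and the softmax_loss_function hook is omitted.

```python
import math
from typing import List, Sequence


def softmax_xent(logit_row: Sequence[float], target: int) -> float:
    """Cross-entropy of softmax(logit_row) against an integer class id."""
    m = max(logit_row)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logit_row))
    return log_z - logit_row[target]


def loss_by_example(logits: List[List[List[float]]],
                    targets: List[List[int]],
                    weights: List[List[float]],
                    average_across_timesteps: bool = True) -> List[float]:
    """Hypothetical per-example loss: logits[t][b] holds class scores for
    time step t and example b; targets[t][b] is a class id; weights[t][b]
    is 1.0 for real tokens and 0.0 for padding."""
    assert len(logits) == len(targets) == len(weights)
    batch_size = len(targets[0])
    totals = [0.0] * batch_size
    weight_sums = [0.0] * batch_size
    for step_logits, step_targets, step_weights in zip(logits, targets, weights):
        for b in range(batch_size):
            totals[b] += softmax_xent(step_logits[b], step_targets[b]) * step_weights[b]
            weight_sums[b] += step_weights[b]
    if average_across_timesteps:
        # Small epsilon avoids division by zero for all-padding examples.
        totals = [t / (w + 1e-12) for t, w in zip(totals, weight_sums)]
    return totals


def loss_by_batch(logits, targets, weights) -> float:
    """Assumed batch loss: mean of the per-example losses."""
    per_example = loss_by_example(logits, targets, weights)
    return sum(per_example) / len(per_example)


def get_sequence_loss(logits, targets, weights, per_example_loss=False):
    """Same dispatch as the original function, minus the TensorFlow plumbing."""
    if per_example_loss:
        return loss_by_example(logits, targets, weights)
    return loss_by_batch(logits, targets, weights)


# Two time steps, a batch of two examples, three classes.
logits = [[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
          [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
targets = [[1, 0], [2, 0]]
weights = [[1.0, 1.0], [1.0, 0.0]]  # example 1 is padded at step 2
print(get_sequence_loss(logits, targets, weights, per_example_loss=True))
print(get_sequence_loss(logits, targets, weights))
```

Padding positions carry weight 0.0, so they add nothing to either an example's summed loss or its weight total; the per-example branch therefore returns one normalized value per sequence, while the batch branch collapses those into a single number suitable for an optimizer.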