def get_sequence_loss(logits, targets, weights, softmax_loss_function=None,
                      per_example_loss=False):
    """Compute a sequence cross-entropy loss from parallel logits/targets lists.

    Each element of ``targets`` is cast to int64 and flattened to a 1-D tensor
    before the whole lists are handed to the sequence-loss helper.

    Args:
      logits: Sequence of logit tensors, same length as ``targets``
        (presumably one tensor per decoding step -- confirm with callers).
      targets: Sequence of target-id tensors, same length as ``logits``.
      weights: Passed through unchanged to the sequence-loss helper.
      softmax_loss_function: Optional loss function forwarded to the helper
        as ``softmax_loss_function``; ``None`` lets the helper use its default.
      per_example_loss: If true, delegate to ``sequence_loss_by_example``;
        otherwise delegate to ``sequence_loss_by_batch``.

    Returns:
      Whatever the selected helper returns.

    Raises:
      AssertionError: If ``logits`` and ``targets`` differ in length (only
        when assertions are enabled, i.e. not under ``python -O``).
    """
    # Validation is shared by both modes, so do it once up front.
    assert len(logits) == len(targets), (
        "logits and targets must have the same length, got %d and %d"
        % (len(logits), len(targets)))

    # Targets must be int64 tensors flattened to rank 1 before they are passed
    # to the sequence-loss helpers.
    bucket_target = [array_ops.reshape(math_ops.to_int64(x), [-1])
                     for x in targets]

    # The two modes differ only in which helper reduces the loss.
    loss_fn = (sequence_loss_by_example if per_example_loss
               else sequence_loss_by_batch)
    return loss_fn(logits, bucket_target, weights,
                   softmax_loss_function=softmax_loss_function)