def sequence_loss_per_sample(logits,
                             targets,
                             weights):
    """Weighted cross-entropy loss for a sequence of logits (per example).

    The per-timestep sparse softmax cross-entropy is multiplied by the
    matching weight, summed over all timesteps and then divided by the summed
    weights (plus a small epsilon). The averaging across timesteps is always
    applied; it is not optional in this variant.

    Args:
      logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
      targets: List of 1D batch-sized int32 Tensors of the same length as
        logits.
      weights: List of 1D batch-sized float-Tensors of the same length as
        logits.

    Returns:
      1D batch-sized float Tensor: the weight-normalised log-perplexity for
      each sequence.

    Raises:
      ValueError: If len(logits) is different from len(targets) or
        len(weights).
    """
    # zip() below would silently truncate to the shortest list while the
    # normaliser still sums *all* weights, so reject mismatched inputs early.
    if len(targets) != len(logits) or len(weights) != len(logits):
        raise ValueError("Lengths of logits, weights, and targets must be the "
                         "same %d, %d, %d."
                         % (len(logits), len(weights), len(targets)))

    with ops.op_scope(logits + targets + weights,
                      None,
                      "sequence_loss_by_example"):
        log_perp_list = []
        for logit, target, weight in zip(logits, targets, weights):
            # Flatten the labels to a 1D int64 vector before the sparse
            # cross-entropy op.
            target = array_ops.reshape(math_ops.to_int64(target), [-1])
            crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
                logits=logit,
                labels=target)
            log_perp_list.append(crossent * weight)
        log_perps = math_ops.add_n(log_perp_list)

        # Average across timesteps: divide by the total label weight.
        total_size = math_ops.add_n(weights)
        total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
        log_perps /= total_size
    return log_perps
评论列表
文章目录