def sequence_classifier(decoding, labels, sampling_decoding=None, name=None):
  """Builds per-step softmax predictions and a summed cross-entropy loss.

  For every step `i`, the softmax cross entropy between `decoding[i]` and
  `labels[i]` is computed. The per-step results are added together with
  `add_n` and the sum is then passed through `reduce_sum` to produce the
  returned loss. Predictions are the softmax of `sampling_decoding[i]` when
  `sampling_decoding` is given and non-empty, otherwise the softmax of
  `decoding[i]`.

  Args:
    decoding: List of Tensors with predictions.
    labels: List of Tensors with labels, indexed in step with `decoding`.
    sampling_decoding: Optional, List of Tensors with predictions to be used
      in sampling. E.g. they shouldn't have dependency on outputs.
      If not provided, decoding is used.
    name: Operation name.

  Returns:
    A tuple `(predictions, loss)`: the per-step softmax outputs packed along
    axis 1, and the loss tensor.
  """
  with ops.name_scope(name, "sequence_classifier", [decoding, labels]):
    step_losses = []
    step_probabilities = []
    for step, step_decoding in enumerate(decoding):
      # NOTE(review): the decoding tensor and the label tensor are passed
      # positionally, in that order; confirm this matches the signature of
      # the TensorFlow version this file targets.
      step_xent = nn.softmax_cross_entropy_with_logits(
          step_decoding, labels[step],
          name="sequence_loss/xent_raw{0}".format(step))
      step_losses.append(step_xent)

      # Prefer the sampling inputs for the reported predictions when the
      # caller supplied them; the loss above always uses `decoding`.
      softmax_input = (
          sampling_decoding[step] if sampling_decoding else step_decoding)
      step_probabilities.append(nn.softmax(softmax_input))

    summed_xent = math_ops.add_n(step_losses, name="sequence_loss/xent")
    loss = math_ops.reduce_sum(summed_xent, name="sequence_loss")
    packed_predictions = array_ops_.pack(step_probabilities, axis=1)
    return packed_predictions, loss
评论列表
文章目录