def weights_prepend_inputs_to_targets(labels):
    """Assign weight 1.0 to only the "targets" portion of the labels.

    Weight 1.0 is assigned to every nonzero label located after the first
    zero along axis 1. Every other position (anything before the first zero,
    and every zero label) gets weight 0.0.

    See prepend_mode in common_hparams.py.

    Args:
        labels: A Tensor of int32s. The first zero is searched for along
            axis 1, so the tensor needs at least two dimensions.

    Returns:
        A float32 Tensor with the same shape as `labels`, holding 0.0 or 1.0.
    """
    # tf.cast(..., tf.float32) replaces the deprecated tf.to_float, which is
    # not available at the top level of TensorFlow 2.x.
    is_zero = tf.cast(tf.equal(labels, 0), tf.float32)
    # Running count of zeros along axis 1: 0.0 before the first zero,
    # >= 1.0 from the first zero onward.
    past_first_zero = tf.cumsum(is_zero, axis=1)
    nonzero = tf.cast(labels, tf.float32)
    # The product is nonzero only where the label itself is nonzero AND at
    # least one zero has already been seen along axis 1.
    return tf.cast(tf.not_equal(past_first_zero * nonzero, 0), tf.float32)
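# Web-page navigation residue, not part of the code: "Comment list"
# Web-page navigation residue, not part of the code: "Article table of contents"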