def weights_concatenated(labels):
"""Assign weight 1.0 to the "target" part of the concatenated labels.
The labels look like:
source English I love you . ID1 target French Je t'aime . ID1 source
English the cat ID1 target French le chat ID1 source English ...
We want to assign weight 1.0 to all words in the target text (including the
ID1 end symbol), but not to the source text or the boilerplate. In the
above example, the target words that get positive weight are:
Je t'aime . ID1 le chat ID1
Args:
labels: a Tensor
Returns:
a Tensor
"""
eos_mask = tf.to_int32(tf.equal(labels, 1))
sentence_num = tf.cumsum(eos_mask, axis=1, exclusive=True)
in_target = tf.equal(tf.mod(sentence_num, 2), 1)
# first two tokens of each sentence are boilerplate.
sentence_num_plus_one = sentence_num + 1
shifted = tf.pad(sentence_num_plus_one,
[[0, 0], [2, 0], [0, 0], [0, 0]])[:, :-2, :, :]
nonboilerplate = tf.equal(sentence_num_plus_one, shifted)
ret = tf.to_float(tf.logical_and(nonboilerplate, in_target))
return ret
评论列表
文章目录