def get_weights(sequence, eos_id, include_first_eos=True):
    """Build a 0/1 float mask that is 1 for positions up to the first EOS.

    For each row of ``sequence``, every position strictly before the first
    occurrence of ``eos_id`` gets weight 1.0 and every later position gets
    0.0. When ``include_first_eos`` is true, the first EOS position itself
    also gets weight 1.0. A row that contains no ``eos_id`` is all ones.

    Args:
        sequence: 2-D tensor indexed as ``[batch, time]`` (the running count
            is taken along axis 1); its elements are compared with
            ``eos_id``.
        eos_id: Token id marking the end of a sequence.
        include_first_eos: Whether the first EOS position is weighted 1.0.

    Returns:
        A float32 tensor with the same shape as ``sequence``, wrapped in
        ``tf.stop_gradient``.
    """
    # 1.0 where the token is not EOS, 0.0 where it is.
    not_eos = tf.cast(tf.not_equal(sequence, eos_id), tf.float32)

    # Position t lies before the first EOS iff every token in [0, t] is
    # non-EOS, i.e. the inclusive running count equals t + 1. To also keep
    # the first EOS, only the tokens in [0, t) must be non-EOS, i.e. the
    # exclusive running count equals t. Expressing the shift this way
    # (instead of slicing off the last column and concatenating ones) keeps
    # the output shape equal to the input shape, including an empty time
    # axis, and preserves static shape information.
    counts = tf.cumsum(not_eos, axis=1, exclusive=bool(include_first_eos))

    start = 0 if include_first_eos else 1
    positions = tf.range(start, tf.shape(sequence)[1] + start)
    # positions has shape [time] and broadcasts against counts [batch, time].
    # Float equality is exact here for any realistic sequence length.
    weights = tf.cast(
        tf.equal(counts, tf.cast(positions, tf.float32)), tf.float32)
    return tf.stop_gradient(weights)
评论列表
文章目录