def mask_by_index(batch_size, input_len, max_time_step):
    """Build a flattened 0/1 weight mask for padded variable-length sequences.

    Row ``b`` of the intermediate (batch_size, max_time_step) mask is 1.0 at
    time steps ``t < input_len[b]`` and 0.0 elsewhere. The mask is then
    flattened row-major, so entry ``b * max_time_step + t`` of the result
    corresponds to batch row ``b``, time step ``t``.

    Args:
        batch_size: Number of sequences in the batch (used as a tile multiple).
        input_len: Per-sequence valid lengths. Assumed to be an int32 vector of
            shape [batch_size] -- TODO confirm against callers.
        max_time_step: Padded sequence length (used for tf.range and as a tile
            multiple).

    Returns:
        A 1-D float32 tensor of length ``batch_size * max_time_step``.
    """
    with tf.variable_scope('Masking'):
        # Per-row length, repeated across the time axis: [batch, time].
        lengths_tiled = tf.tile(tf.expand_dims(input_len, 1), [1, max_time_step])
        # Local time indices 0..max_time_step-1, repeated per row: [batch, time].
        range_tiled = tf.tile(tf.expand_dims(tf.range(0, max_time_step), 0),
                              [batch_size, 1])
        # Compare local time indices with the row length itself. A flattened
        # index (b * max_time_step + len - 1) must not be used here: it would
        # leave every row b >= 1 completely unmasked.
        mask = tf.less(range_tiled, lengths_tiled)
        # tf.cast yields the same float32 1.0/0.0 values as selecting between
        # ones/zeros, and works on TF versions both before and after the
        # tf.select -> tf.where rename in TF 1.0.
        weight = tf.cast(mask, tf.float32)
        return tf.reshape(weight, [-1])
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")