def target_mask_op(data, pad_val=0):
    """Build an int32 mask marking the positions of ``data`` that are not padding.

    Author note (HangSheng): this also returns a mask tensor when the input is
    a ``tf.string`` tensor. In that case ``pad_val`` presumably has to be a
    string as well (e.g. ``b""``), because the default ``0`` cannot be compared
    against string elements -- TODO confirm against callers.

    Args:
        data: A tensor whose static rank must be 2 or 3. For rank 3, position
            ``[i, j]`` is marked 1 if any element along axis 2 differs from
            ``pad_val``. For rank 2, each element is compared directly.
        pad_val: The value treated as padding. Defaults to 0.

    Returns:
        An int32 tensor holding 1 for non-padding and 0 for padding. For
        rank-2 input it has the same shape as ``data``. For rank-3 input it
        has ``data``'s shape with axis 2 removed.

    Raises:
        ValueError: if the static rank of ``data`` is unknown, is 1, or is
            anything other than 2 or 3.
    """
    data_shape_size = data.get_shape().ndims
    if data_shape_size is None:
        # Without a static rank we cannot tell which axis to reduce over.
        # Previously this case fell through to the generic message below and
        # misleadingly reported the rank as "None".
        raise ValueError("target_mask_op: data must have a statically known rank!")
    if data_shape_size == 3:
        # Collapse axis 2: a position counts as real if any element differs
        # from the pad value.
        not_pad = tf.reduce_any(tf.not_equal(data, pad_val), axis=2)
        return tf.cast(not_pad, dtype=tf.int32)
    if data_shape_size == 2:
        return tf.cast(tf.not_equal(data, pad_val), dtype=tf.int32)
    if data_shape_size == 1:
        raise ValueError("target_mask_op: data has wrong shape!")
    raise ValueError(
        "target_mask_op: handling data_shape_size %s hasn't been implemented!"
        % (data_shape_size)
    )
# Dynamic RNN
# (Scraped web-page residue, kept for reference: "评论列表" = "Comment list",
#  "文章目录" = "Table of contents".)