def get_mask(sequence_length):
    """Build an additive look-ahead mask of shape [sequence_length, sequence_length].

    Entries on or below the main diagonal are 0.0 and entries strictly above it
    are -1e9. Presumably the result is added to attention logits before a
    softmax so that position i cannot attend to positions j > i. TODO confirm
    against callers.

    Args:
        sequence_length: Number of rows/columns of the square mask (int or
            scalar int tensor).

    Returns:
        A float32 tensor (the default dtype of ``tf.ones``) of shape
        [sequence_length, sequence_length].
    """
    # band_part(x, -1, 0) keeps every sub-diagonal and the main diagonal and
    # zeroes everything above it, giving a lower-triangular matrix of ones.
    # tf.linalg.band_part is the same op as the former tf.matrix_band_part,
    # which no longer exists at the top level in TF2.
    lower_triangle = tf.linalg.band_part(
        tf.ones([sequence_length, sequence_length]), -1, 0
    )
    # 1 - lower_triangle is 1 exactly on the strictly-upper triangle, so only
    # those "future" positions receive the large negative value.
    return -1e9 * (1.0 - lower_triangle)