def get_mask(batch_size, sequence_length):
    """Build an additive causal (look-ahead) mask.

    Entry ``[i, j]`` of the returned matrix is ``0.0`` when ``j <= i`` and
    ``-1e9`` when ``j > i``. Each position therefore keeps itself and earlier
    positions and blocks later ones. The result is presumably added to
    attention logits before a softmax, so blocked positions get ~zero
    weight (NOTE(review): confirm at call sites).

    Args:
        batch_size: Unused; kept only so existing callers keep working. The
            returned mask has no batch dimension. Callers presumably rely on
            broadcasting (TODO confirm).
        sequence_length: Size of both dimensions of the square mask.

    Returns:
        A float32 tensor of shape ``[sequence_length, sequence_length]``.
    """
    # num_lower=-1 keeps the entire lower triangle; num_upper=0 drops every
    # entry above the diagonal. The result has ones on/below the diagonal and
    # zeros above it. tf.linalg.band_part is used because the old alias
    # tf.matrix_band_part was removed in TensorFlow 2.x.
    lower_triangle = tf.linalg.band_part(
        tf.ones([sequence_length, sequence_length]), -1, 0
    )
    # Map allowed positions (1.0) to 0.0 and blocked positions (0.0) to a
    # large negative value, so adding the mask suppresses blocked logits.
    return -1e9 * (1.0 - lower_triangle)
评论列表
文章目录