def ones_matrix_band_part(rows, cols, num_lower, num_upper, out_shape=None):
    """Matrix band part of ones.

    Produces a [rows, cols] matrix whose entry (m, n) is 1 when
    (m - n) <= num_lower and (n - m) <= num_upper, and 0 otherwise.
    A negative num_lower / num_upper keeps the entire lower / upper triangle.

    Args:
      rows: number of rows; a Python int, or a value tf.ones accepts as a
        dimension.
      cols: number of columns; same kinds of value as rows.
      num_lower: number of sub-diagonals to keep (negative keeps all).
      num_upper: number of super-diagonals to keep (negative keeps all).
      out_shape: optional shape the band is reshaped to before returning;
        skipped when falsy.

    Returns:
      The band matrix. When all four size arguments are Python ints it is
      built in numpy and returned as a float32 tf.constant; otherwise it is
      the result of tf.matrix_band_part applied to tf.ones([rows, cols]).
    """
    if all(isinstance(el, int) for el in (rows, cols, num_lower, num_upper)):
        # Needed info is constant, so we construct in numpy
        if num_lower < 0:
            num_lower = rows - 1
        if num_upper < 0:
            num_upper = cols - 1
        # np.tri(N, M, k)[i, j] is 1 iff j <= i + k.
        # The lower mask needs [m, n] = 1 iff m - n <= num_lower, so it is
        # built with the dimensions swapped, (cols, rows), and transposed back
        # to (rows, cols). Building it as (rows, cols) before the transpose
        # only works for square matrices.
        lower_mask = np.tri(cols, rows, num_lower).T
        # The upper mask is [m, n] = 1 iff n - m <= num_upper.
        upper_mask = np.tri(rows, cols, num_upper)
        band = lower_mask * upper_mask
        if out_shape:
            band = band.reshape(out_shape)
        band = tf.constant(band, tf.float32)
    else:
        # At least one size is not a Python int, so defer to the TF op.
        band = tf.matrix_band_part(tf.ones([rows, cols]),
                                   tf.cast(num_lower, tf.int64),
                                   tf.cast(num_upper, tf.int64))
        if out_shape:
            band = tf.reshape(band, out_shape)
    return band
# Comment list
# Article table of contents