common_layers.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def ones_matrix_band_part(rows, cols, num_lower, num_upper, out_shape=None):
  """Matrix band part of ones."""
  if all([isinstance(el, int) for el in [rows, cols, num_lower, num_upper]]):
    # Needed info is constant, so we construct in numpy
    if num_lower < 0:
      num_lower = rows - 1
    if num_upper < 0:
      num_upper = cols - 1
    lower_mask = np.tri(rows, cols, num_lower).T
    upper_mask = np.tri(rows, cols, num_upper)
    band = np.ones((rows, cols)) * lower_mask * upper_mask
    if out_shape:
      band = band.reshape(out_shape)
    band = tf.constant(band, tf.float32)
  else:
    band = tf.matrix_band_part(tf.ones([rows, cols]),
                               tf.cast(num_lower, tf.int64),
                               tf.cast(num_upper, tf.int64))
    if out_shape:
      band = tf.reshape(band, out_shape)

  return band
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号