common_layers.py source code

python

Project: tensor2tensor  Author: tensorflow
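# Written against TensorFlow 1.x APIs (tf.to_float, keep_dims).
import tensorflow as tf


def standardize_images(x):
  """Image standardization on batches (tf.image.per_image_standardization)."""
  with tf.name_scope("standardize_images", [x]):
    x = tf.to_float(x)
    # Per-image mean and variance over height, width, and channels.
    x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keep_dims=True)
    x_variance = tf.reduce_mean(
        tf.square(x - x_mean), axis=[1, 2, 3], keep_dims=True)
    # shape_list is a helper defined elsewhere in common_layers.py.
    x_shape = shape_list(x)
    # Element count per image; the channel count is hard-coded to 3.
    num_pixels = tf.to_float(x_shape[1] * x_shape[2] * 3)
    # Floor the standard deviation at 1/sqrt(num_pixels) so that uniform
    # images do not cause a division by zero.
    x = (x - x_mean) / tf.maximum(tf.sqrt(x_variance), tf.rsqrt(num_pixels))
    # TODO(lukaszkaiser): remove hack below, needed for greedy decoding for now.
    if x.shape and len(x.shape) == 4 and x.shape[3] == 1:
      x = tf.concat([x, x, x], axis=3)  # Not used, just a dead tf.cond branch.
    x.set_shape([None, None, None, 3])
    return x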
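
The arithmetic in standardize_images can be checked without TensorFlow. The block below is a minimal NumPy sketch of the same per-image computation, not part of tensor2tensor: the name standardize_images_np is hypothetical, the input is assumed to be a [batch, height, width, 3] array, and the single-channel tiling hack is left out.

```python
import numpy as np


def standardize_images_np(x):
    """NumPy sketch of the per-image standardization above (hypothetical helper)."""
    x = np.asarray(x, dtype=np.float32)
    # Per-image mean and variance over height, width, and channels.
    mean = x.mean(axis=(1, 2, 3), keepdims=True)
    variance = np.square(x - mean).mean(axis=(1, 2, 3), keepdims=True)
    # Mirrors the source: element count with the channel count fixed at 3.
    num_pixels = np.float32(x.shape[1] * x.shape[2] * 3)
    # Floor the standard deviation at 1/sqrt(num_pixels) to avoid dividing by zero.
    min_stddev = 1.0 / np.sqrt(num_pixels)
    return (x - mean) / np.maximum(np.sqrt(variance), min_stddev)


rng = np.random.default_rng(0)
batch = rng.uniform(0, 255, size=(2, 4, 4, 3))  # two random 4x4 RGB images
out = standardize_images_np(batch)
print(out.shape)
print(out.mean(axis=(1, 2, 3)), out.std(axis=(1, 2, 3)))
```

Each image in the result should have a mean close to 0 and a standard deviation close to 1. A constant image maps to all zeros, because the floor on the standard deviation keeps the denominator positive.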