convnet_builder.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:stuff 作者: yaroslavvb 项目源码 文件源码
def _batch_norm_without_layers(self, input_layer, decay, use_scale, epsilon):
    """Apply fused batch normalization to `input_layer` without tf.layers.

    The structure deliberately mirrors tf.contrib.layers.batch_norm so that
    the layers and non-layers code paths differ as little as possible.

    Args:
      input_layer: Input tensor. Its channel count is read from axis 3 when
        `self.data_format` is 'NHWC' and from axis 1 otherwise (presumably
        'NCHW' -- confirm against callers).
      decay: Decay forwarded to `assign_moving_average` for the moving
        mean/variance updates.
      use_scale: If true, a 'gamma' variable is created through
        `self.get_variable`; otherwise a constant tensor of ones is used.
      epsilon: Epsilon forwarded to `tf.nn.fused_batch_norm`.

    Returns:
      The normalized tensor, i.e. the first output of
      `tf.nn.fused_batch_norm`.

    When `self.phase_train` is truthy, the moving-average update ops are
    added to the `tf.GraphKeys.UPDATE_OPS` collection. This function only
    registers them; it does not attach them as control dependencies.
    """
    # Pick the channel axis from the data layout, then read its size.
    if self.data_format == 'NHWC':
      channel_axis = 3
    else:
      channel_axis = 1
    num_channels = input_layer.shape[channel_axis]
    param_shape = [num_channels]

    # NOTE(review): the two positional tf.float32 arguments are passed through
    # to self.get_variable unchanged; their exact meaning is defined by that
    # helper, which is not visible here.
    beta = self.get_variable(
        'beta', param_shape, tf.float32, tf.float32,
        initializer=tf.zeros_initializer())
    if use_scale:
      gamma = self.get_variable(
          'gamma', param_shape, tf.float32, tf.float32,
          initializer=tf.ones_initializer())
    else:
      # No learned scale: use a fixed multiplier of one per channel.
      gamma = tf.constant(1.0, tf.float32, param_shape)

    # The moving statistics are created with tf.get_variable rather than
    # self.get_variable: the latter returns the result of tf.cast, which
    # cannot be assigned to.
    moving_mean = tf.get_variable(
        'moving_mean', param_shape, tf.float32,
        initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = tf.get_variable(
        'moving_variance', param_shape, tf.float32,
        initializer=tf.ones_initializer(), trainable=False)

    if not self.phase_train:
      # Inference path: normalize with the stored moving statistics.
      normalized, _, _ = tf.nn.fused_batch_norm(
          input_layer, gamma, beta,
          mean=moving_mean, variance=moving_variance,
          epsilon=epsilon, data_format=self.data_format, is_training=False)
      return normalized

    # Training path: normalize with batch statistics and fold them into the
    # moving averages.
    normalized, batch_mean, batch_variance = tf.nn.fused_batch_norm(
        input_layer, gamma, beta,
        epsilon=epsilon, data_format=self.data_format, is_training=True)
    mean_update = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay=decay, zero_debias=False)
    variance_update = moving_averages.assign_moving_average(
        moving_variance, batch_variance, decay=decay, zero_debias=False)
    for update_op in (mean_update, variance_update):
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)
    return normalized
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号