network.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:GC-Net 作者: Jiankai-Sun 项目源码 文件源码
def bn(x, c):
  """Batch-normalize `x` along its last dimension.

  Moments are reduced over every axis except the last one, so `beta`,
  `gamma`, `moving_mean` and `moving_variance` all take the shape of the
  last dimension of `x`.

  Args:
    x: Tensor to normalize. The reduction axes are derived from
      `len(x.get_shape())`.
    c: Config mapping. Only `c['is_training']` is read here; it is used as
      the predicate of `cond` (presumably a boolean tensor -- confirm
      against the caller).

  Returns:
    The output of `tf.nn.batch_normalization` applied to `x`, using the
    batch moments when the `is_training` predicate holds and the moving
    averages otherwise.
  """
  input_shape = x.get_shape()
  channel_shape = input_shape[-1:]
  reduce_axes = list(range(len(input_shape) - 1))

  def _bn_variable(name, initializer, trainable):
    # The four BN variables differ only in name, initializer and
    # trainability; shape, dtype and collections are shared.
    return tf.get_variable(
        name,
        shape=channel_shape,
        initializer=initializer,
        dtype='float32',
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, GC_VARIABLES],
        trainable=trainable)

  # Learned offset and scale.
  beta = _bn_variable('beta', tf.zeros_initializer(), True)
  gamma = _bn_variable('gamma', tf.ones_initializer(), True)

  # Running statistics, updated through moving averages rather than by
  # gradient descent.
  moving_mean = _bn_variable('moving_mean', tf.zeros_initializer(), False)
  moving_variance = _bn_variable('moving_variance', tf.ones_initializer(),
                                 False)

  batch_mean, batch_variance = tf.nn.moments(x, reduce_axes)

  # The moving-average updates are only registered in a collection here;
  # this function does not execute them itself. Presumably the training
  # step runs UPDATE_OPS_COLLECTION -- confirm against the caller.
  for target, batch_stat in ((moving_mean, batch_mean),
                             (moving_variance, batch_variance)):
    update_op = moving_averages.assign_moving_average(
        target, batch_stat, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

  # Select batch statistics or the stored moving averages at run time.
  use_mean, use_variance = control_flow_ops.cond(
      c['is_training'],
      lambda: (batch_mean, batch_variance),
      lambda: (moving_mean, moving_variance))

  return tf.nn.batch_normalization(
      x, use_mean, use_variance, beta, gamma, BN_EPSILON)


# resnet block
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号