network.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:gc_net_stereo 作者: MaidouPP 项目源码 文件源码
def bn(x, c):
  """Batch-normalize `x` using statistics taken over all but its last axis.

  Four variables shaped like the last dimension of `x` are obtained through
  `_get_variable`: 'beta' (zeros, used as the offset), 'gamma' (ones, used as
  the scale), and the non-trainable 'moving_mean' (zeros) and
  'moving_variance' (ones).

  Moving-average update ops for the two running statistics are built with
  decay `BN_DECAY` and appended to `UPDATE_OPS_COLLECTION`; this function
  does not run them itself.

  Args:
    x: tensor to normalize. The reduction axes are derived from
      `len(x.get_shape())`.
    c: config mapping. Only `c['is_training']` is read; it is passed as the
      predicate of `control_flow_ops.cond` (presumably a boolean tensor --
      confirm against the caller).

  Returns:
    The output of `tf.nn.batch_normalization` for `x`, normalized with the
    batch moments when `c['is_training']` selects the first branch and with
    the moving statistics otherwise.
  """
  input_shape = x.get_shape()
  channel_shape = input_shape[-1:]
  # Every axis except the last one is reduced when computing moments.
  reduce_axes = list(range(len(input_shape) - 1))

  # Offset and scale parameters.
  beta = _get_variable(
      'beta', channel_shape, initializer=tf.zeros_initializer())
  gamma = _get_variable(
      'gamma', channel_shape, initializer=tf.ones_initializer())

  # Running statistics, kept out of the trainable set.
  moving_mean = _get_variable(
      'moving_mean', channel_shape,
      initializer=tf.zeros_initializer(), trainable=False)
  moving_variance = _get_variable(
      'moving_variance', channel_shape,
      initializer=tf.ones_initializer(), trainable=False)

  batch_mean, batch_variance = tf.nn.moments(x, reduce_axes)

  # The update ops are only registered here; whoever consumes
  # UPDATE_OPS_COLLECTION decides when they execute.
  update_ops = [
      moving_averages.assign_moving_average(
          moving_mean, batch_mean, BN_DECAY),
      moving_averages.assign_moving_average(
          moving_variance, batch_variance, BN_DECAY),
  ]
  for update_op in update_ops:
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

  def _batch_stats():
    return batch_mean, batch_variance

  def _moving_stats():
    return moving_mean, moving_variance

  # Batch statistics on the true branch, running statistics on the false one.
  mean, variance = control_flow_ops.cond(
      c['is_training'], _batch_stats, _moving_stats)

  return tf.nn.batch_normalization(
      x, mean, variance, beta, gamma, BN_EPSILON)

# wrapper for get_variable op
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号