resnet.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:Skeleton-key 作者: feiyu1990 项目源码 文件源码
def _bn(self, x, params_init, is_training):
        """Batch-normalize ``x`` with parameters initialised from ``params_init``.

        Args:
            x: Input tensor. No ``axis`` is passed to
                ``tf.layers.batch_normalization``, so the layer's default axis
                (-1 per the TF docs) is normalized as the channel axis.
            params_init: Mapping with the initial BN values under the keys
                ``'bias'`` (beta), ``'weight'`` (gamma), ``'running_mean'`` and
                ``'running_var'``. NOTE(review): these key names look like a
                PyTorch BatchNorm state dict, presumably converted pretrained
                weights. Confirm against the loader.
            is_training: Forwarded as ``training=`` to
                ``tf.layers.batch_normalization``. It selects batch statistics
                (true) or the moving averages (false).

        Returns:
            The tensor produced by ``tf.layers.batch_normalization``.
        """
        # NOTE(review): these four variables are never read by this method.
        # tf.layers.batch_normalization below creates and initialises its own
        # copies. The calls are presumably kept so these variable names exist
        # in the graph (e.g. for checkpoint compatibility). Confirm before
        # removing them.
        self._get_variable_const('beta', initializer=tf.constant(params_init['bias']))
        self._get_variable_const('gamma', initializer=tf.constant(params_init['weight']))
        self._get_variable_const('moving_mean',
                                 initializer=tf.constant(params_init['running_mean']),
                                 trainable=False)
        self._get_variable_const('moving_variance',
                                 initializer=tf.constant(params_init['running_var']),
                                 trainable=False)

        # NOTE: per the TF docs, when training is true the moving-average update
        # ops are placed in tf.GraphKeys.UPDATE_OPS. They only run if the train
        # op is made to depend on them. Confirm the caller does this.
        return tf.layers.batch_normalization(
            x,
            momentum=BN_DECAY,
            epsilon=BN_EPSILON,
            beta_initializer=tf.constant_initializer(params_init['bias']),
            gamma_initializer=tf.constant_initializer(params_init['weight']),
            moving_mean_initializer=tf.constant_initializer(params_init['running_mean']),
            moving_variance_initializer=tf.constant_initializer(params_init['running_var']),
            training=is_training)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号